yym68686 commited on
Commit
923b378
·
1 Parent(s): 8ad5d0e

✨ Feature: 1. Add support for experimental frontend.

Browse files

2. Add support for configuration files without listing models.

.gitignore CHANGED
@@ -1,5 +1,5 @@
1
  api.json
2
- api.yaml
3
  .env
4
  __pycache__
5
  .vscode
 
1
  api.json
2
+ *.yaml
3
  .env
4
  __pycache__
5
  .vscode
components/provider_table.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from xue import Div, Table, Thead, Tbody, Tr, Th, Td, Button, Input, Script, Head, Style, Span
2
+ from xue.components.checkbox import checkbox
3
+ from xue.components.dropdown import dropdown_menu, dropdown_menu_content
4
+ from xue.components.button import button
5
+ from xue.components.input import input
6
+
7
+ Head.add_default_children([
8
+ Style("""
9
+ .data-table-container {
10
+ width: 100%;
11
+ overflow-x: auto;
12
+ border: 1px solid #e2e8f0;
13
+ border-radius: 0.5rem;
14
+ overflow-x: visible !important;
15
+ }
16
+ .data-table {
17
+ width: 100%;
18
+ border-collapse: separate;
19
+ border-spacing: 0;
20
+ }
21
+ .data-table th, .data-table td {
22
+ padding: 0.75rem 1rem;
23
+ text-align: left;
24
+ border-bottom: 1px solid #e2e8f0;
25
+ }
26
+ .data-table th {
27
+ font-weight: 500;
28
+ font-size: 0.875rem;
29
+ color: #4b5563;
30
+ height: 2.5rem;
31
+ transition: background-color 0.2s;
32
+ }
33
+ .data-table thead tr:hover th,
34
+ .data-table tbody tr:hover {
35
+ background-color: #f8fafc;
36
+ }
37
+ .data-table tbody tr:last-child td {
38
+ border-bottom: none;
39
+ }
40
+ .sortable-header {
41
+ cursor: pointer;
42
+ user-select: none;
43
+ display: inline-flex;
44
+ align-items: center;
45
+ padding: 0.25rem 0.5rem;
46
+ border-radius: 0.25rem;
47
+ transition: background-color 0.2s;
48
+ }
49
+ .sortable-header:hover {
50
+ background-color: #e5e7eb;
51
+ }
52
+ .sort-icon {
53
+ display: inline-block;
54
+ width: 1rem;
55
+ height: 1rem;
56
+ margin-left: 0.25rem;
57
+ transition: transform 0.2s;
58
+ opacity: 0;
59
+ }
60
+ .sortable-header:hover .sort-icon,
61
+ .sort-asc .sort-icon,
62
+ .sort-desc .sort-icon {
63
+ opacity: 1;
64
+ }
65
+ .sort-asc .sort-icon {
66
+ transform: rotate(180deg);
67
+ }
68
+ .table-header {
69
+ display: flex;
70
+ justify-content: space-between;
71
+ align-items: center;
72
+ margin-bottom: 1rem;
73
+ }
74
+ .table-footer {
75
+ display: flex;
76
+ justify-content: space-between;
77
+ align-items: center;
78
+ margin-top: 1rem;
79
+ }
80
+ .pagination {
81
+ display: flex;
82
+ gap: 0.5rem;
83
+ }
84
+ @media (prefers-color-scheme: dark) {
85
+ .data-table-container {
86
+ border-color: #4b5563;
87
+ }
88
+ .data-table th, .data-table td {
89
+ border-color: #4b5563;
90
+ }
91
+ .data-table th {
92
+ color: #d1d5db;
93
+ }
94
+ .data-table thead tr:hover th,
95
+ .data-table tbody tr:hover {
96
+ background-color: #1f2937;
97
+ }
98
+ .sortable-header:hover {
99
+ background-color: #374151;
100
+ }
101
+ }
102
+ """, id="data-table-style"),
103
+ Script("""
104
+ function toggleAllRows(checked) {
105
+ const checkboxes = document.querySelectorAll('.row-checkbox');
106
+ checkboxes.forEach(cb => cb.checked = checked);
107
+ updateSelectedCount();
108
+ }
109
+
110
+ function updateSelectedCount() {
111
+ const selectedCount = document.querySelectorAll('.row-checkbox:checked').length;
112
+ const totalCount = document.querySelectorAll('.row-checkbox').length;
113
+ document.getElementById('selected-count').textContent = `${selectedCount} of ${totalCount} row(s) selected.`;
114
+ }
115
+
116
+ function sortTable(columnIndex, accessor) {
117
+ const table = document.querySelector('.data-table');
118
+ const header = table.querySelector(`th[data-accessor="${accessor}"]`);
119
+ const isAscending = !header.classList.contains('sort-asc');
120
+
121
+ // Update sort direction
122
+ table.querySelectorAll('th').forEach(th => th.classList.remove('sort-asc', 'sort-desc'));
123
+ header.classList.add(isAscending ? 'sort-asc' : 'sort-desc');
124
+
125
+ // Sort the table
126
+ const rows = Array.from(table.querySelectorAll('tbody tr'));
127
+ rows.sort((a, b) => {
128
+ const aValue = a.querySelector(`td[data-accessor="${accessor}"]`).textContent;
129
+ const bValue = b.querySelector(`td[data-accessor="${accessor}"]`).textContent;
130
+ return isAscending ? aValue.localeCompare(bValue) : bValue.localeCompare(aValue);
131
+ });
132
+
133
+ // Update the table
134
+ const tbody = table.querySelector('tbody');
135
+ rows.forEach(row => tbody.appendChild(row));
136
+ }
137
+
138
+ document.addEventListener('change', function(event) {
139
+ if (event.target.classList.contains('row-checkbox')) {
140
+ updateSelectedCount();
141
+ }
142
+ });
143
+ """, id="data-table-script"),
144
+ ])
145
+
146
+ def data_table(columns, data, id, with_filter=True):
147
+ return Div(
148
+ Div(
149
+ input(type="text", placeholder="Filter...", id=f"{id}-filter", class_="mr-auto"),
150
+ Div(
151
+ button(
152
+ "Add Provider",
153
+ variant="secondary",
154
+ hx_get="/add-provider-sheet",
155
+ hx_target="#sheet-container",
156
+ hx_swap="innerHTML",
157
+ class_="h-[2.625rem]"
158
+ ),
159
+ dropdown_menu("Columns"),
160
+ ),
161
+ class_="table-header flex items-center"
162
+ ) if with_filter else None,
163
+ Div(
164
+ Div(
165
+ Table(
166
+ Thead(
167
+ Tr(
168
+ Th(checkbox("select-all", "", onclick="toggleAllRows(this.checked)")),
169
+ *[Th(
170
+ Div(
171
+ col['label'],
172
+ Span("▼", class_="sort-icon"),
173
+ class_="sortable-header" if col.get('sortable', False) else "",
174
+ onclick=f"sortTable({i}, '{col['value']}')" if col.get('sortable', False) else None
175
+ ),
176
+ data_accessor=col['value']
177
+ ) for i, col in enumerate(columns)],
178
+ Th("Actions") # 新增的操作列
179
+ )
180
+ ),
181
+ Tbody(
182
+ *[Tr(
183
+ Td(checkbox(f"row-{i}", "", class_="row-checkbox")),
184
+ *[Td(row[col['value']], data_accessor=col['value']) for col in columns],
185
+ Td(row_actions_menu(i)), # 使用行索引作为 row_id
186
+ id=f"row-{i}"
187
+ ) for i, row in enumerate(data)]
188
+ ),
189
+ class_="data-table"
190
+ ),
191
+ class_="data-table-container"
192
+ ),
193
+ Div(
194
+ Div(id="selected-count", class_="text-sm text-gray-500"),
195
+ Div(
196
+ button("Previous", variant="outline", class_="mr-2"),
197
+ button("Next", variant="outline"),
198
+ class_="pagination"
199
+ ),
200
+ class_="table-footer"
201
+ ),
202
+ id=id
203
+ ),
204
+ )
205
+
206
+ def get_column_visibility_menu(id, columns):
207
+ return dropdown_menu_content(id, [
208
+ {"label": col['label'], "value": col['value']}
209
+ for col in columns if col.get('can_hide', True)
210
+ ])
211
+
212
+ def row_actions_menu(row_id):
213
+ return dropdown_menu("⋮", id=f"row-actions-menu-{row_id}", hx_get=f"/dropdown-menu/dropdown-menu-⋮/{row_id}")
214
+
215
+ def get_row_actions_menu(row_id):
216
+ return dropdown_menu_content(f"row-actions-{row_id}", [
217
+ {"label": "Edit", "icon": "pencil"},
218
+ {"label": "Duplicate", "icon": "copy"},
219
+ {"label": "Delete", "icon": "trash"},
220
+ "separator",
221
+ {"label": "More...", "icon": "more-horizontal"},
222
+ ])
223
+
224
+ def render_row(row_data, row_id, columns):
225
+ return Tr(
226
+ Td(checkbox(f"row-{row_id}", "", class_="row-checkbox")),
227
+ *[Td(row_data[col['value']], data_accessor=col['value']) for col in columns],
228
+ Td(row_actions_menu(row_id)),
229
+ id=f"row-{row_id}"
230
+ ).render()
main.py CHANGED
@@ -18,7 +18,7 @@ from fastapi.exceptions import RequestValidationError
18
  from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
19
  from request import get_payload
20
  from response import fetch_response, fetch_response_stream
21
- from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
22
 
23
  from collections import defaultdict
24
  from typing import List, Dict, Union
@@ -492,20 +492,21 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
492
  else:
493
  engine = "gpt"
494
 
495
- if "claude" not in provider['model'][request.model] \
496
- and "gpt" not in provider['model'][request.model] \
497
- and "gemini" not in provider['model'][request.model] \
 
498
  and parsed_url.netloc != 'api.cloudflare.com' \
499
  and parsed_url.netloc != 'api.cohere.com':
500
  engine = "openrouter"
501
 
502
- if "claude" in provider['model'][request.model] and engine == "vertex":
503
  engine = "vertex-claude"
504
 
505
- if "gemini" in provider['model'][request.model] and engine == "vertex":
506
  engine = "vertex-gemini"
507
 
508
- if "o1-preview" in provider['model'][request.model] or "o1-mini" in provider['model'][request.model]:
509
  engine = "o1"
510
  request.stream = False
511
 
@@ -536,7 +537,7 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
536
  current_info = request_info.get()
537
  try:
538
  if request.stream:
539
- model = provider['model'][request.model]
540
  generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
541
  wrapped_generator, first_response_time = await error_handling_wrapper(generator)
542
  response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream")
@@ -603,7 +604,8 @@ class ModelRequestHandler:
603
  if model == "all":
604
  # 如果模型名为 *,则返回所有模型
605
  for provider in config["providers"]:
606
- for model in provider["model"].keys():
 
607
  provider_rules.append(provider["provider"] + "/" + model)
608
  break
609
  if "/" in model:
@@ -611,15 +613,17 @@ class ModelRequestHandler:
611
  model = model[1:-1]
612
  # 处理带斜杠的模型名
613
  for provider in config['providers']:
614
- if model in provider['model'].keys():
 
615
  provider_rules.append(provider['provider'] + "/" + model)
616
  else:
617
  provider_name = model.split("/")[0]
618
  model_name_split = "/".join(model.split("/")[1:])
619
  models_list = []
620
  for provider in config['providers']:
 
621
  if provider['provider'] == provider_name:
622
- models_list.extend(list(provider['model'].keys()))
623
  # print("models_list", models_list)
624
  # print("model_name", model_name)
625
  # print("model_name_split", model_name_split)
@@ -632,7 +636,8 @@ class ModelRequestHandler:
632
  provider_rules.append(provider_name)
633
  else:
634
  for provider in config['providers']:
635
- if model in provider['model'].keys():
 
636
  provider_rules.append(provider['provider'] + "/" + model)
637
 
638
  provider_list = []
@@ -642,10 +647,13 @@ class ModelRequestHandler:
642
  # print("provider", provider, provider['provider'] == item, item)
643
  if "/" in item:
644
  if provider['provider'] == item.split("/")[0]:
645
- if model_name in provider['model'].keys() and "/".join(item.split("/")[1:]) == model_name:
 
646
  provider_list.append(provider)
 
647
  elif provider['provider'] == item:
648
- if model_name in provider['model'].keys():
 
649
  provider_list.append(provider)
650
  else:
651
  pass
@@ -655,7 +663,8 @@ class ModelRequestHandler:
655
  # if item.split("/")[1] == model_name:
656
  # provider_list.append(provider)
657
  # else:
658
- # if model_name in provider['model'].keys():
 
659
  # provider_list.append(provider)
660
  if is_debug:
661
  for provider in provider_list:
 
18
  from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
19
  from request import get_payload
20
  from response import fetch_response, fetch_response_stream
21
+ from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict
22
 
23
  from collections import defaultdict
24
  from typing import List, Dict, Union
 
492
  else:
493
  engine = "gpt"
494
 
495
+ model_dict = get_model_dict(provider)
496
+ if "claude" not in model_dict[request.model] \
497
+ and "gpt" not in model_dict[request.model] \
498
+ and "gemini" not in model_dict[request.model] \
499
  and parsed_url.netloc != 'api.cloudflare.com' \
500
  and parsed_url.netloc != 'api.cohere.com':
501
  engine = "openrouter"
502
 
503
+ if "claude" in model_dict[request.model] and engine == "vertex":
504
  engine = "vertex-claude"
505
 
506
+ if "gemini" in model_dict[request.model] and engine == "vertex":
507
  engine = "vertex-gemini"
508
 
509
+ if "o1-preview" in model_dict[request.model] or "o1-mini" in model_dict[request.model]:
510
  engine = "o1"
511
  request.stream = False
512
 
 
537
  current_info = request_info.get()
538
  try:
539
  if request.stream:
540
+ model = model_dict[request.model]
541
  generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
542
  wrapped_generator, first_response_time = await error_handling_wrapper(generator)
543
  response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream")
 
604
  if model == "all":
605
  # 如果模型名为 *,则返回所有模型
606
  for provider in config["providers"]:
607
+ model_dict = get_model_dict(provider)
608
+ for model in model_dict.keys():
609
  provider_rules.append(provider["provider"] + "/" + model)
610
  break
611
  if "/" in model:
 
613
  model = model[1:-1]
614
  # 处理带斜杠的模型名
615
  for provider in config['providers']:
616
+ model_dict = get_model_dict(provider)
617
+ if model in model_dict.keys():
618
  provider_rules.append(provider['provider'] + "/" + model)
619
  else:
620
  provider_name = model.split("/")[0]
621
  model_name_split = "/".join(model.split("/")[1:])
622
  models_list = []
623
  for provider in config['providers']:
624
+ model_dict = get_model_dict(provider)
625
  if provider['provider'] == provider_name:
626
+ models_list.extend(list(model_dict.keys()))
627
  # print("models_list", models_list)
628
  # print("model_name", model_name)
629
  # print("model_name_split", model_name_split)
 
636
  provider_rules.append(provider_name)
637
  else:
638
  for provider in config['providers']:
639
+ model_dict = get_model_dict(provider)
640
+ if model in model_dict.keys():
641
  provider_rules.append(provider['provider'] + "/" + model)
642
 
643
  provider_list = []
 
647
  # print("provider", provider, provider['provider'] == item, item)
648
  if "/" in item:
649
  if provider['provider'] == item.split("/")[0]:
650
+ model_dict = get_model_dict(provider)
651
+ if model_name in model_dict.keys() and "/".join(item.split("/")[1:]) == model_name:
652
  provider_list.append(provider)
653
+ # 如果 item 不包含 /,则直接匹配 provider,说明整个渠道所有模型都能用
654
  elif provider['provider'] == item:
655
+ model_dict = get_model_dict(provider)
656
+ if model_name in model_dict.keys():
657
  provider_list.append(provider)
658
  else:
659
  pass
 
663
  # if item.split("/")[1] == model_name:
664
  # provider_list.append(provider)
665
  # else:
666
+ # model_dict = get_model_dict(provider)
667
+ # if model_name in model_dict.keys():
668
  # provider_list.append(provider)
669
  if is_debug:
670
  for provider in provider_list:
models.py CHANGED
@@ -1,5 +1,5 @@
1
  from io import IOBase
2
- from pydantic import BaseModel, Field, model_validator
3
  from typing import List, Dict, Optional, Union, Tuple, Literal, Any
4
  from log_config import logger
5
 
@@ -61,10 +61,16 @@ class ToolChoice(BaseModel):
61
  class BaseRequest(BaseModel):
62
  request_type: Optional[Literal["chat", "image", "audio", "moderation"]] = Field(default=None, exclude=True)
63
 
64
- class JsonSchema(BaseModel):
65
- name: str
66
- schema: Dict[str, Any]
 
 
 
 
 
67
 
 
68
  class ResponseFormat(BaseModel):
69
  type: Literal["text", "json_object", "json_schema"]
70
  json_schema: Optional[JsonSchema] = None
 
1
  from io import IOBase
2
+ from pydantic import BaseModel, Field, model_validator, ConfigDict
3
  from typing import List, Dict, Optional, Union, Tuple, Literal, Any
4
  from log_config import logger
5
 
 
61
  class BaseRequest(BaseModel):
62
  request_type: Optional[Literal["chat", "image", "audio", "moderation"]] = Field(default=None, exclude=True)
63
 
64
+ def create_json_schema_class():
65
+ class JsonSchema(BaseModel):
66
+ name: str
67
+
68
+ model_config = ConfigDict(protected_namespaces=())
69
+
70
+ JsonSchema.__annotations__['schema'] = Dict[str, Any]
71
+ return JsonSchema
72
 
73
+ JsonSchema = create_json_schema_class()
74
  class ResponseFormat(BaseModel):
75
  type: Literal["text", "json_object", "json_schema"]
76
  json_schema: Optional[JsonSchema] = None
request.py CHANGED
@@ -6,7 +6,7 @@ import base64
6
  import urllib.parse
7
 
8
  from models import RequestModel
9
- from utils import c35s, c3s, c3o, c3h, gem, BaseAPI
10
 
11
  import imghdr
12
 
@@ -120,13 +120,14 @@ async def get_gemini_payload(request, engine, provider):
120
  headers = {
121
  'Content-Type': 'application/json'
122
  }
123
- model = provider['model'][request.model]
 
124
  gemini_stream = "streamGenerateContent"
125
  url = provider['base_url']
126
  if url.endswith("v1beta"):
127
- url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=provider['api'].next())
128
  if url.endswith("v1"):
129
- url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=provider['api'].next())
130
 
131
  messages = []
132
  systemInstruction = None
@@ -312,7 +313,8 @@ async def get_vertex_gemini_payload(request, engine, provider):
312
  project_id = provider.get("project_id")
313
 
314
  gemini_stream = "streamGenerateContent"
315
- model = provider['model'][request.model]
 
316
  location = gem
317
  url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL_ID=model, stream=gemini_stream)
318
 
@@ -449,7 +451,8 @@ async def get_vertex_claude_payload(request, engine, provider):
449
  if provider.get("project_id"):
450
  project_id = provider.get("project_id")
451
 
452
- model = provider['model'][request.model]
 
453
  if "claude-3-5-sonnet" in model:
454
  location = c35s
455
  elif "claude-3-opus" in model:
@@ -460,7 +463,7 @@ async def get_vertex_claude_payload(request, engine, provider):
460
  location = c3h
461
 
462
  claude_stream = "streamRawPredict"
463
- url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL=model, stream=claude_stream)
464
 
465
  messages = []
466
  system_prompt = None
@@ -534,7 +537,8 @@ async def get_vertex_claude_payload(request, engine, provider):
534
  else:
535
  message_index = message_index + 1
536
 
537
- model = provider['model'][request.model]
 
538
  payload = {
539
  "anthropic_version": "vertex-2023-10-16",
540
  "messages": messages,
@@ -593,7 +597,7 @@ async def get_gpt_payload(request, engine, provider):
593
  'Content-Type': 'application/json',
594
  }
595
  if provider.get("api"):
596
- headers['Authorization'] = f"Bearer {provider['api'].next()}"
597
  url = provider['base_url']
598
 
599
  messages = []
@@ -633,7 +637,8 @@ async def get_gpt_payload(request, engine, provider):
633
  else:
634
  messages.append({"role": msg.role, "content": content})
635
 
636
- model = provider['model'][request.model]
 
637
  payload = {
638
  "model": model,
639
  "messages": messages,
@@ -659,7 +664,7 @@ async def get_openrouter_payload(request, engine, provider):
659
  'Content-Type': 'application/json'
660
  }
661
  if provider.get("api"):
662
- headers['Authorization'] = f"Bearer {provider['api'].next()}"
663
 
664
  url = provider['base_url']
665
 
@@ -691,7 +696,8 @@ async def get_openrouter_payload(request, engine, provider):
691
  else:
692
  messages.append({"role": msg.role, "content": content})
693
 
694
- model = provider['model'][request.model]
 
695
  payload = {
696
  "model": model,
697
  "messages": messages,
@@ -725,7 +731,7 @@ async def get_cohere_payload(request, engine, provider):
725
  'Content-Type': 'application/json'
726
  }
727
  if provider.get("api"):
728
- headers['Authorization'] = f"Bearer {provider['api'].next()}"
729
 
730
  url = provider['base_url']
731
 
@@ -753,7 +759,8 @@ async def get_cohere_payload(request, engine, provider):
753
  else:
754
  messages.append({"role": role_map[msg.role], "message": content})
755
 
756
- model = provider['model'][request.model]
 
757
  chat_history = messages[:-1]
758
  query = messages[-1].get("message")
759
  payload = {
@@ -792,9 +799,10 @@ async def get_cloudflare_payload(request, engine, provider):
792
  'Content-Type': 'application/json'
793
  }
794
  if provider.get("api"):
795
- headers['Authorization'] = f"Bearer {provider['api'].next()}"
796
 
797
- model = provider['model'][request.model]
 
798
  url = "https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/ai/run/{cf_model_id}".format(cf_account_id=provider['cf_account_id'], cf_model_id=model)
799
 
800
  msg = request.messages[-1]
@@ -808,7 +816,7 @@ async def get_cloudflare_payload(request, engine, provider):
808
  content = msg.content
809
  name = msg.name
810
 
811
- model = provider['model'][request.model]
812
  payload = {
813
  "prompt": content,
814
  }
@@ -841,7 +849,7 @@ async def get_o1_payload(request, engine, provider):
841
  'Content-Type': 'application/json'
842
  }
843
  if provider.get("api"):
844
- headers['Authorization'] = f"Bearer {provider['api'].next()}"
845
 
846
  url = provider['base_url']
847
 
@@ -863,7 +871,8 @@ async def get_o1_payload(request, engine, provider):
863
  elif msg.role != "system":
864
  messages.append({"role": msg.role, "content": content})
865
 
866
- model = provider['model'][request.model]
 
867
  payload = {
868
  "model": model,
869
  "messages": messages,
@@ -912,10 +921,11 @@ async def gpt2claude_tools_json(json_dict):
912
  return json_dict
913
 
914
  async def get_claude_payload(request, engine, provider):
915
- model = provider['model'][request.model]
 
916
  headers = {
917
  "content-type": "application/json",
918
- "x-api-key": f"{provider['api'].next()}",
919
  "anthropic-version": "2023-06-01",
920
  "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in model else "tools-2024-05-16",
921
  }
@@ -993,7 +1003,8 @@ async def get_claude_payload(request, engine, provider):
993
  else:
994
  message_index = message_index + 1
995
 
996
- model = provider['model'][request.model]
 
997
  payload = {
998
  "model": model,
999
  "messages": messages,
@@ -1051,12 +1062,13 @@ async def get_claude_payload(request, engine, provider):
1051
  return url, headers, payload
1052
 
1053
  async def get_dalle_payload(request, engine, provider):
1054
- model = provider['model'][request.model]
 
1055
  headers = {
1056
  "Content-Type": "application/json",
1057
  }
1058
  if provider.get("api"):
1059
- headers['Authorization'] = f"Bearer {provider['api'].next()}"
1060
  url = provider['base_url']
1061
  url = BaseAPI(url).image_url
1062
 
@@ -1070,12 +1082,13 @@ async def get_dalle_payload(request, engine, provider):
1070
  return url, headers, payload
1071
 
1072
  async def get_whisper_payload(request, engine, provider):
1073
- model = provider['model'][request.model]
 
1074
  headers = {
1075
  # "Content-Type": "multipart/form-data",
1076
  }
1077
  if provider.get("api"):
1078
- headers['Authorization'] = f"Bearer {provider['api'].next()}"
1079
  url = provider['base_url']
1080
  url = BaseAPI(url).audio_transcriptions
1081
 
@@ -1096,12 +1109,13 @@ async def get_whisper_payload(request, engine, provider):
1096
  return url, headers, payload
1097
 
1098
  async def get_moderation_payload(request, engine, provider):
1099
- model = provider['model'][request.model]
 
1100
  headers = {
1101
  "Content-Type": "application/json",
1102
  }
1103
  if provider.get("api"):
1104
- headers['Authorization'] = f"Bearer {provider['api'].next()}"
1105
  url = provider['base_url']
1106
  url = BaseAPI(url).moderations
1107
 
 
6
  import urllib.parse
7
 
8
  from models import RequestModel
9
+ from utils import c35s, c3s, c3o, c3h, gem, BaseAPI, get_model_dict, provider_api_circular_list
10
 
11
  import imghdr
12
 
 
120
  headers = {
121
  'Content-Type': 'application/json'
122
  }
123
+ model_dict = get_model_dict(provider)
124
+ model = model_dict[request.model]
125
  gemini_stream = "streamGenerateContent"
126
  url = provider['base_url']
127
  if url.endswith("v1beta"):
128
+ url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
129
  if url.endswith("v1"):
130
+ url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
131
 
132
  messages = []
133
  systemInstruction = None
 
313
  project_id = provider.get("project_id")
314
 
315
  gemini_stream = "streamGenerateContent"
316
+ model_dict = get_model_dict(provider)
317
+ model = model_dict[request.model]
318
  location = gem
319
  url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL_ID=model, stream=gemini_stream)
320
 
 
451
  if provider.get("project_id"):
452
  project_id = provider.get("project_id")
453
 
454
+ model_dict = get_model_dict(provider)
455
+ model = model_dict[request.model]
456
  if "claude-3-5-sonnet" in model:
457
  location = c35s
458
  elif "claude-3-opus" in model:
 
463
  location = c3h
464
 
465
  claude_stream = "streamRawPredict"
466
+ url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(LOCATION=await location.next(), PROJECT_ID=project_id, MODEL=model, stream=claude_stream)
467
 
468
  messages = []
469
  system_prompt = None
 
537
  else:
538
  message_index = message_index + 1
539
 
540
+ model_dict = get_model_dict(provider)
541
+ model = model_dict[request.model]
542
  payload = {
543
  "anthropic_version": "vertex-2023-10-16",
544
  "messages": messages,
 
597
  'Content-Type': 'application/json',
598
  }
599
  if provider.get("api"):
600
+ headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
601
  url = provider['base_url']
602
 
603
  messages = []
 
637
  else:
638
  messages.append({"role": msg.role, "content": content})
639
 
640
+ model_dict = get_model_dict(provider)
641
+ model = model_dict[request.model]
642
  payload = {
643
  "model": model,
644
  "messages": messages,
 
664
  'Content-Type': 'application/json'
665
  }
666
  if provider.get("api"):
667
+ headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
668
 
669
  url = provider['base_url']
670
 
 
696
  else:
697
  messages.append({"role": msg.role, "content": content})
698
 
699
+ model_dict = get_model_dict(provider)
700
+ model = model_dict[request.model]
701
  payload = {
702
  "model": model,
703
  "messages": messages,
 
731
  'Content-Type': 'application/json'
732
  }
733
  if provider.get("api"):
734
+ headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
735
 
736
  url = provider['base_url']
737
 
 
759
  else:
760
  messages.append({"role": role_map[msg.role], "message": content})
761
 
762
+ model_dict = get_model_dict(provider)
763
+ model = model_dict[request.model]
764
  chat_history = messages[:-1]
765
  query = messages[-1].get("message")
766
  payload = {
 
799
  'Content-Type': 'application/json'
800
  }
801
  if provider.get("api"):
802
+ headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
803
 
804
+ model_dict = get_model_dict(provider)
805
+ model = model_dict[request.model]
806
  url = "https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/ai/run/{cf_model_id}".format(cf_account_id=provider['cf_account_id'], cf_model_id=model)
807
 
808
  msg = request.messages[-1]
 
816
  content = msg.content
817
  name = msg.name
818
 
819
+ model = model_dict[request.model]
820
  payload = {
821
  "prompt": content,
822
  }
 
849
  'Content-Type': 'application/json'
850
  }
851
  if provider.get("api"):
852
+ headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
853
 
854
  url = provider['base_url']
855
 
 
871
  elif msg.role != "system":
872
  messages.append({"role": msg.role, "content": content})
873
 
874
+ model_dict = get_model_dict(provider)
875
+ model = model_dict[request.model]
876
  payload = {
877
  "model": model,
878
  "messages": messages,
 
921
  return json_dict
922
 
923
  async def get_claude_payload(request, engine, provider):
924
+ model_dict = get_model_dict(provider)
925
+ model = model_dict[request.model]
926
  headers = {
927
  "content-type": "application/json",
928
+ "x-api-key": f"{await provider_api_circular_list[provider['provider']].next()}",
929
  "anthropic-version": "2023-06-01",
930
  "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in model else "tools-2024-05-16",
931
  }
 
1003
  else:
1004
  message_index = message_index + 1
1005
 
1006
+ model_dict = get_model_dict(provider)
1007
+ model = model_dict[request.model]
1008
  payload = {
1009
  "model": model,
1010
  "messages": messages,
 
1062
  return url, headers, payload
1063
 
1064
  async def get_dalle_payload(request, engine, provider):
1065
+ model_dict = get_model_dict(provider)
1066
+ model = model_dict[request.model]
1067
  headers = {
1068
  "Content-Type": "application/json",
1069
  }
1070
  if provider.get("api"):
1071
+ headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
1072
  url = provider['base_url']
1073
  url = BaseAPI(url).image_url
1074
 
 
1082
  return url, headers, payload
1083
 
1084
  async def get_whisper_payload(request, engine, provider):
1085
+ model_dict = get_model_dict(provider)
1086
+ model = model_dict[request.model]
1087
  headers = {
1088
  # "Content-Type": "multipart/form-data",
1089
  }
1090
  if provider.get("api"):
1091
+ headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
1092
  url = provider['base_url']
1093
  url = BaseAPI(url).audio_transcriptions
1094
 
 
1109
  return url, headers, payload
1110
 
1111
  async def get_moderation_payload(request, engine, provider):
1112
+ model_dict = get_model_dict(provider)
1113
+ model = model_dict[request.model]
1114
  headers = {
1115
  "Content-Type": "application/json",
1116
  }
1117
  if provider.get("api"):
1118
+ headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
1119
  url = provider['base_url']
1120
  url = BaseAPI(url).moderations
1121
 
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- pyyaml
2
  pytest
3
  uvicorn
4
  fastapi
@@ -7,6 +7,7 @@ greenlet
7
  aiosqlite
8
  sqlalchemy
9
  watchfiles
 
10
  httpx[http2]
11
  cryptography
12
  python-multipart
 
1
+ xue
2
  pytest
3
  uvicorn
4
  fastapi
 
7
  aiosqlite
8
  sqlalchemy
9
  watchfiles
10
+ ruamel.yaml
11
  httpx[http2]
12
  cryptography
13
  python-multipart
test/test_ruamel_yaml.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ruamel.yaml import YAML
2
+
3
+ # 假设我们有以下 YAML 内容
4
+ yaml_content = """
5
+ # 这是顶级注释
6
+ key1: value1 # 行尾注释
7
+ key2: value2
8
+
9
+ # 这是嵌套结构的注释
10
+ nested:
11
+ subkey1: subvalue1
12
+ subkey2: subvalue2 # 嵌套的行尾注释
13
+
14
+ # 列表的注释
15
+ list_key:
16
+ - item1
17
+ - item2 # 列表项的注释
18
+ """
19
+
20
+ # 创建 YAML 对象
21
+ yaml = YAML()
22
+ yaml.preserve_quotes = True
23
+ yaml.indent(mapping=2, sequence=4, offset=2)
24
+
25
+ with open('api.yaml', 'r', encoding='utf-8') as file:
26
+ data = yaml.load(file)
27
+
28
+ # data = yaml.load(yaml_content)
29
+ # 加载 YAML 数据
30
+ print(data)
31
+
32
+ # # 修改数据
33
+ # data['key1'] = 'new_value1'
34
+ # data['nested']['subkey1'] = 'new_subvalue1'
35
+ # data['list_key'].append('new_item')
36
+
37
+ # 将修改后的数据写回文件(这里我们使用 StringIO 来模拟文件操作)
38
+ # from io import StringIO
39
+ # output = StringIO()
40
+ # yaml.dump(data, output)
41
+ # print(output.getvalue())
42
+
43
+ with open('formatted.yaml', 'w', encoding='utf-8') as file:
44
+ yaml.dump(data, file)
test/xue/test_dropdown_sheet.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.responses import HTMLResponse
3
+ from xue import HTML, Head, Body, Div, xue_initialize, Script
4
+ from xue.components import dropdown, sheet, button, form, input
5
+
6
+ xue_initialize(tailwind=True)
7
+ app = FastAPI()
8
+
9
+ @app.get("/", response_class=HTMLResponse)
10
+ async def root():
11
+ result = HTML(
12
+ Head(
13
+ title="Dropdown with Edit Sheet Example",
14
+ ),
15
+ Body(
16
+ Div(
17
+ dropdown.dropdown_menu("Actions"),
18
+ Div(id="sheet-container"), # 这里是 sheet 将被加载的地方
19
+ class_="container mx-auto p-4"
20
+ )
21
+ )
22
+ ).render()
23
+ print(result)
24
+ return result
25
+
26
+ @app.get("/dropdown-menu/{menu_id}", response_class=HTMLResponse)
27
+ async def get_dropdown_menu_content(menu_id: str):
28
+ items = [
29
+ {
30
+ "icon": "pencil",
31
+ "label": "Edit",
32
+ "hx-get": "/edit-sheet",
33
+ "hx-target": "#sheet-container",
34
+ "hx-swap": "innerHTML"
35
+ },
36
+ {"icon": "trash", "label": "Delete"},
37
+ {"icon": "copy", "label": "Duplicate"},
38
+ ]
39
+ result = dropdown.dropdown_menu_content(menu_id, items).render()
40
+ print("dropdown-menu result", result)
41
+ return result
42
+
43
+ @app.get("/edit-sheet", response_class=HTMLResponse)
44
+ async def get_edit_sheet():
45
+ edit_sheet_content = sheet.SheetContent(
46
+ sheet.SheetHeader(
47
+ sheet.SheetTitle("Edit Item"),
48
+ sheet.SheetDescription("Make changes to your item here.")
49
+ ),
50
+ sheet.SheetBody(
51
+ form.Form(
52
+ form.FormField("Name", "name", placeholder="Enter item name"),
53
+ form.FormField("Description", "description", placeholder="Enter item description"),
54
+ Div(
55
+ button.button("Save", class_="bg-blue-500 text-white"),
56
+ button.button("Cancel", class_="bg-gray-300 text-gray-700 ml-2", data_close_sheet="true"),
57
+ class_="flex justify-end mt-4"
58
+ ),
59
+ class_="space-y-4"
60
+ )
61
+ )
62
+ )
63
+
64
+ result = sheet.Sheet(
65
+ "edit-sheet",
66
+ Div(),
67
+ edit_sheet_content,
68
+ width="80%",
69
+ max_width="800px"
70
+ ).render()
71
+ return result
72
+
73
+ if __name__ == "__main__":
74
+ import uvicorn
75
+ uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
test/xue/test_form_uni_api.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Form as FastAPIForm
2
+ from fastapi.responses import HTMLResponse
3
+ from xue import HTML, Head, Body, Div, xue_initialize, Strong, Span, Ul, Li
4
+ from xue.components import form, button, checkbox, input
5
+ from xue.components.model_config_row import model_config_row
6
+ from typing import List, Optional
7
+ import time
8
+
9
+ xue_initialize(tailwind=True)
10
+ app = FastAPI()
11
+
12
+ @app.get("/", response_class=HTMLResponse)
13
+ async def root():
14
+ result = HTML(
15
+ Head(
16
+ title="Provider Configuration Form"
17
+ ),
18
+ Body(
19
+ Div(
20
+ form.Form(
21
+ form.FormField("Provider", "provider", placeholder="Enter provider name", required=True),
22
+ form.FormField("Base URL", "base_url", placeholder="Enter base URL", required=True),
23
+ form.FormField("API Key", "api_key", type="password", placeholder="Enter API key"),
24
+ Div(
25
+ Div("Models", class_="text-lg font-semibold mb-2"),
26
+ Div(
27
+ model_config_row("model1", "gpt-4o: deepbricks-gpt-4o-mini", True),
28
+ model_config_row("model2", "gpt-4o"),
29
+ model_config_row("model3", "gpt-3.5-turbo"),
30
+ model_config_row("model4", "claude-3-5-sonnet-20240620: claude-3-5-sonnet"),
31
+ model_config_row("model5", "o1-mini-all"),
32
+ model_config_row("model6", "o1-preview-all"),
33
+ model_config_row("model7", "whisper-1"),
34
+ id="models-container"
35
+ ),
36
+ button.button(
37
+ "Add Model",
38
+ class_="mt-2",
39
+ hx_post="/add-model",
40
+ hx_target="#models-container",
41
+ hx_swap="beforeend"
42
+ ),
43
+ class_="mb-4"
44
+ ),
45
+ Div(
46
+ checkbox.checkbox("tools", "Enable Tools", checked=True),
47
+ class_="mb-4"
48
+ ),
49
+ form.FormField("Notes", "notes", placeholder="Enter any additional notes"),
50
+ Div(
51
+ button.button("Submit", class_="bg-blue-500 text-white"),
52
+ button.button("Cancel", class_="bg-gray-300 text-gray-700 ml-2"),
53
+ class_="flex justify-end mt-4"
54
+ ),
55
+ hx_post="/submit",
56
+ hx_swap="outerHTML",
57
+ class_="space-y-4"
58
+ ),
59
+ class_="container mx-auto p-4 max-w-2xl"
60
+ )
61
+ )
62
+ ).render()
63
+ print(result)
64
+ return result
65
+
66
+ @app.post("/add-model", response_class=HTMLResponse)
67
+ async def add_model():
68
+ new_model_id = f"model{hash(str(time.time()))}" # 生成一个唯一的ID
69
+ new_model = model_config_row(new_model_id).render()
70
+ return new_model
71
+
72
+ def form_success_message(provider, base_url, api_key, models, tools_enabled, notes):
73
+ return Div(
74
+ Strong("Success!", class_="font-bold"),
75
+ Span("Form submitted successfully.", class_="block sm:inline"),
76
+ Ul(
77
+ Li(f"Provider: {provider}"),
78
+ Li(f"Base URL: {base_url}"),
79
+ Li(f"API Key: {'*' * len(api_key)}"),
80
+ Li(f"Models: {', '.join(models)}"),
81
+ Li(f"Tools Enabled: {'Yes' if tools_enabled else 'No'}"),
82
+ Li(f"Notes: {notes}"),
83
+ class_="mt-3"
84
+ ),
85
+ class_="bg-green-100 border border-green-400 text-green-700 px-4 py-3 rounded relative",
86
+ role="alert"
87
+ )
88
+
89
+ @app.post("/submit", response_class=HTMLResponse)
90
+ async def submit_form(
91
+ provider: str = FastAPIForm(...),
92
+ base_url: str = FastAPIForm(...),
93
+ api_key: str = FastAPIForm(...),
94
+ models: List[str] = FastAPIForm([]),
95
+ tools: Optional[str] = FastAPIForm(None),
96
+ notes: Optional[str] = FastAPIForm(None)
97
+ ):
98
+ # 处理提交的数据
99
+ print(f"Received: provider={provider}, base_url={base_url}, api_key={api_key}")
100
+ print(f"Models: {models}")
101
+ print(f"Tools Enabled: {tools is not None}")
102
+ print(f"Notes: {notes}")
103
+
104
+ # 返回处理结果
105
+ return form_success_message(
106
+ provider,
107
+ base_url,
108
+ api_key,
109
+ models,
110
+ tools is not None,
111
+ notes or "No notes provided"
112
+ ).render()
113
+
114
+ if __name__ == "__main__":
115
+ import uvicorn
116
+ uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
test/xue/test_home.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
+ from fastapi import Form as FastapiForm, HTTPException, Depends
3
+ from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
4
+ from fastapi.security import APIKeyHeader
5
+ from typing import Optional, List
6
+
7
+ from xue import HTML, Head, Body, Div, xue_initialize, Script
8
+ from xue.components.menubar import (
9
+ Menubar, MenubarMenu, MenubarTrigger, MenubarContent,
10
+ MenubarItem, MenubarSeparator
11
+ )
12
+ from xue.components import input
13
+ from xue.components import dropdown, sheet, form, button, checkbox
14
+ from xue.components.model_config_row import model_config_row
15
+ import time
16
+
17
+ import sys
18
+ import os
19
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
20
+ from components.provider_table import data_table
21
+
22
+
23
+ from ruamel.yaml import YAML
24
+ yaml = YAML()
25
+ yaml.preserve_quotes = True
26
+ yaml.indent(mapping=2, sequence=4, offset=2)
27
+
28
+ xue_initialize(tailwind=True)
29
+
30
+ from starlette.middleware.base import BaseHTTPMiddleware
31
+ import logging
32
+
33
+ logging.basicConfig(level=logging.INFO)
34
+ logger = logging.getLogger(__name__)
35
+
36
+ class RequestBodyLoggerMiddleware(BaseHTTPMiddleware):
37
+ async def dispatch(self, request: Request, call_next):
38
+ if request.method == "POST" and request.url.path.startswith("/submit/"):
39
+ # if request.method == "POST":
40
+ body = await request.body()
41
+ logger.info(f"Request body for {request.url.path}: {body.decode()}")
42
+
43
+ response = await call_next(request)
44
+ return response
45
+
46
+ from utils import load_config
47
+ from contextlib import asynccontextmanager
48
+ @asynccontextmanager
49
+ async def lifespan(app: FastAPI):
50
+ # app.state.client = httpx.AsyncClient(timeout=timeout)
51
+ app.state.config, app.state.api_keys_db, app.state.api_list = await load_config()
52
+ for item in app.state.api_keys_db:
53
+ if item.get("role") == "admin":
54
+ app.state.admin_api_key = item.get("api")
55
+ if not hasattr(app.state, "admin_api_key"):
56
+ if len(app.state.api_keys_db) >= 1:
57
+ app.state.admin_api_key = app.state.api_keys_db[0].get("api")
58
+ else:
59
+ raise Exception("No admin API key found")
60
+
61
+ global data
62
+ # providers_data = app.state.config["providers"]
63
+
64
+ # print("data", data)
65
+ yield
66
+ # 关闭时的代码
67
+ await app.state.client.aclose()
68
+
69
+ app = FastAPI(lifespan=lifespan)
70
+ # app.add_middleware(RequestBodyLoggerMiddleware)
71
+ app.add_middleware(RequestBodyLoggerMiddleware)
72
+
73
+ data_table_columns = [
74
+ # {"label": "Status", "value": "status", "sortable": True},
75
+ {"label": "Provider", "value": "provider", "sortable": True},
76
+ {"label": "Base url", "value": "base_url", "sortable": True},
77
+ # {"label": "Engine", "value": "engine", "sortable": True},
78
+ {"label": "Tools", "value": "tools", "sortable": True},
79
+ ]
80
+
81
+ API_KEY_NAME = "X-API-Key"
82
+ api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
83
+
84
+ @app.get("/login", response_class=HTMLResponse)
85
+ async def login_page():
86
+ return HTML(
87
+ Head(title="登录"),
88
+ Body(
89
+ Div(
90
+ form.Form(
91
+ form.FormField("API Key", "x_api_key", type="password", placeholder="输入API密钥", required=True),
92
+ Div(id="error-message", class_="text-red-500 mt-2"),
93
+ Div(
94
+ button.button("提交", variant="primary", type="submit"),
95
+ class_="flex justify-end mt-4"
96
+ ),
97
+ hx_post="/verify-api-key",
98
+ hx_target="#error-message",
99
+ hx_swap="innerHTML",
100
+ class_="space-y-4"
101
+ ),
102
+ class_="container mx-auto p-4 max-w-md"
103
+ )
104
+ )
105
+ ).render()
106
+
107
+
108
+ @app.post("/verify-api-key", response_class=HTMLResponse)
109
+ async def verify_api_key(x_api_key: str = FastapiForm(...)):
110
+ if x_api_key == app.state.admin_api_key: # 替换为实际的管理员API密钥
111
+ response = JSONResponse(content={"success": True})
112
+ response.headers["HX-Redirect"] = "/" # 添加这一行
113
+ response.set_cookie(
114
+ key="x_api_key",
115
+ value=x_api_key,
116
+ httponly=True,
117
+ max_age=1800, # 30分钟
118
+ secure=False, # 在开发环境中设置为False,生产环境中使用HTTPS时设置为True
119
+ samesite="lax" # 改为"lax"以允许重定向时携带cookie
120
+ )
121
+ return response
122
+ else:
123
+ return Div("无效的API密钥", class_="text-red-500").render()
124
+
125
+ async def get_api_key(request: Request, x_api_key: Optional[str] = Depends(api_key_header)):
126
+ if not x_api_key:
127
+ x_api_key = request.cookies.get("x_api_key") or request.query_params.get("x_api_key")
128
+ # print(f"Cookie x_api_key: {request.cookies.get('x_api_key')}") # 添加此行
129
+ # print(f"Query param x_api_key: {request.query_params.get('x_api_key')}") # 添加此行
130
+ # print(f"Header x_api_key: {x_api_key}") # 添加此��
131
+ # logger.info(f"x_api_key: {x_api_key} {x_api_key == 'your_admin_api_key'}")
132
+
133
+ if x_api_key == app.state.admin_api_key: # 替换为实际的管理员API密钥
134
+ return x_api_key
135
+ else:
136
+ return None
137
+
138
+ @app.get("/", response_class=HTMLResponse)
139
+ async def root(x_api_key: str = Depends(get_api_key)):
140
+ if not x_api_key:
141
+ return RedirectResponse(url="/login", status_code=303)
142
+
143
+ result = HTML(
144
+ Head(
145
+ Script("""
146
+ document.addEventListener('DOMContentLoaded', function() {
147
+ const filterInput = document.getElementById('users-table-filter');
148
+ filterInput.addEventListener('input', function() {
149
+ const filterValue = this.value;
150
+ htmx.ajax('GET', `/filter-table?filter=${filterValue}`, '#users-table');
151
+ });
152
+ });
153
+ """),
154
+ title="Menubar Example"
155
+ ),
156
+ Body(
157
+ Div(
158
+ Menubar(
159
+ MenubarMenu(
160
+ MenubarTrigger("File", "file-menu"),
161
+ MenubarContent(
162
+ MenubarItem("New Tab", shortcut="⌘T"),
163
+ MenubarItem("New Window", shortcut="⌘N"),
164
+ MenubarItem("New Incognito Window", disabled=True),
165
+ MenubarSeparator(),
166
+ MenubarItem("Print...", shortcut="⌘P"),
167
+ ),
168
+ id="file-menu"
169
+ ),
170
+ MenubarMenu(
171
+ MenubarTrigger("Edit", "edit-menu"),
172
+ MenubarContent(
173
+ MenubarItem("Undo", shortcut="⌘Z"),
174
+ MenubarItem("Redo", shortcut="⇧⌘Z"),
175
+ MenubarSeparator(),
176
+ MenubarItem("Cut"),
177
+ MenubarItem("Copy"),
178
+ MenubarItem("Paste"),
179
+ ),
180
+ id="edit-menu"
181
+ ),
182
+ MenubarMenu(
183
+ MenubarTrigger("View", "view-menu"),
184
+ MenubarContent(
185
+ MenubarItem("Always Show Bookmarks Bar"),
186
+ MenubarItem("Always Show Full URLs"),
187
+ MenubarSeparator(),
188
+ MenubarItem("Reload", shortcut="⌘R"),
189
+ MenubarItem("Force Reload", shortcut="⇧⌘R", disabled=True),
190
+ MenubarSeparator(),
191
+ MenubarItem("Toggle Fullscreen"),
192
+ MenubarItem("Hide Sidebar"),
193
+ ),
194
+ id="view-menu"
195
+ ),
196
+ ),
197
+ class_="p-4"
198
+ ),
199
+ Div(
200
+ data_table(data_table_columns, app.state.config["providers"], "users-table"),
201
+ class_="p-4"
202
+ ),
203
+ Div(id="sheet-container"), # 这里是 sheet 将被加载的地方
204
+ class_="container mx-auto",
205
+ id="body"
206
+ )
207
+ ).render()
208
+ # print(result)
209
+ return result
210
+
211
+ @app.get("/dropdown-menu/{menu_id}/{row_id}", response_class=HTMLResponse)
212
+ async def get_columns_menu(menu_id: str, row_id: str):
213
+ columns = [
214
+ {
215
+ "label": "Edit",
216
+ "value": "edit",
217
+ "hx-get": f"/edit-sheet/{row_id}",
218
+ "hx-target": "#sheet-container",
219
+ "hx-swap": "innerHTML"
220
+ },
221
+ {
222
+ "label": "Duplicate",
223
+ "value": "duplicate",
224
+ "hx-post": f"/duplicate/{row_id}",
225
+ "hx-target": "body",
226
+ "hx-swap": "outerHTML"
227
+ },
228
+ {
229
+ "label": "Delete",
230
+ "value": "delete",
231
+ "hx-delete": f"/delete/{row_id}",
232
+ "hx-target": "body",
233
+ "hx-swap": "outerHTML",
234
+ "hx-confirm": "确定要删除这个配置吗?"
235
+ },
236
+ ]
237
+ result = dropdown.dropdown_menu_content(menu_id, columns).render()
238
+ print(result)
239
+ return result
240
+
241
+ @app.get("/dropdown-menu/{menu_id}", response_class=HTMLResponse)
242
+ async def get_columns_menu(menu_id: str):
243
+ result = dropdown.dropdown_menu_content(menu_id, data_table_columns).render()
244
+ print(result)
245
+ return result
246
+
247
+ @app.get("/filter-table", response_class=HTMLResponse)
248
+ async def filter_table(filter: str = ""):
249
+ filtered_data = [
250
+ provider for provider in app.state.config["providers"]
251
+ if filter.lower() in str(provider["provider"]).lower() or
252
+ filter.lower() in str(provider["base_url"]).lower() or
253
+ filter.lower() in str(provider["tools"]).lower()
254
+ ]
255
+ return data_table(data_table_columns, filtered_data, "users-table", with_filter=False).render()
256
+
257
+ @app.post("/add-model", response_class=HTMLResponse)
258
+ async def add_model():
259
+ new_model_id = f"model{hash(str(time.time()))}" # 生成一个唯一的ID
260
+ new_model = model_config_row(new_model_id).render()
261
+ return new_model
262
+
263
+ @app.get("/edit-sheet/{row_id}", response_class=HTMLResponse)
264
+ async def get_edit_sheet(row_id: str, x_api_key: str = Depends(get_api_key)):
265
+ row_data = get_row_data(row_id)
266
+ print("row_data", row_data)
267
+
268
+ model_list = []
269
+ for index, model in enumerate(row_data["model"]):
270
+ if isinstance(model, str):
271
+ model_list.append(model_config_row(f"model{index}", model, "", True))
272
+ if isinstance(model, dict):
273
+ # print("model", model, list(model.items())[0])
274
+ key, value = list(model.items())[0]
275
+ model_list.append(model_config_row(f"model{index}", key, value, True))
276
+
277
+ sheet_id = "edit-sheet"
278
+ edit_sheet_content = sheet.SheetContent(
279
+ sheet.SheetHeader(
280
+ sheet.SheetTitle("Edit Item"),
281
+ sheet.SheetDescription("Make changes to your item here.")
282
+ ),
283
+ sheet.SheetBody(
284
+ Div(
285
+ form.Form(
286
+ form.FormField("Provider", "provider", value=row_data["provider"], placeholder="Enter provider name", required=True),
287
+ form.FormField("Base URL", "base_url", value=row_data["base_url"], placeholder="Enter base URL", required=True),
288
+ form.FormField("API Key", "api_key", value=row_data["api"], type="text", placeholder="Enter API key"),
289
+ Div(
290
+ Div("Models", class_="text-lg font-semibold mb-2"),
291
+ Div(
292
+ *model_list,
293
+ id="models-container"
294
+ ),
295
+ button.button(
296
+ "Add Model",
297
+ class_="mt-2",
298
+ hx_post="/add-model",
299
+ hx_target="#models-container",
300
+ hx_swap="beforeend"
301
+ ),
302
+ class_="mb-4"
303
+ ),
304
+ Div(
305
+ checkbox.checkbox("tools", "Enable Tools", checked=row_data["tools"], name="tools"),
306
+ class_="mb-4"
307
+ ),
308
+ form.FormField("Notes", "notes", value=row_data.get("notes", ""), placeholder="Enter any additional notes"),
309
+ Div(
310
+ button.button("Submit", variant="primary", type="submit"),
311
+ button.button("Cancel", variant="outline", type="button", class_="ml-2", onclick=f"toggleSheet('{sheet_id}')"),
312
+ class_="flex justify-end mt-4"
313
+ ),
314
+ hx_post=f"/submit/{row_id}",
315
+ hx_swap="outerHTML",
316
+ hx_target="body",
317
+ class_="space-y-4"
318
+ ),
319
+ class_="container mx-auto p-4 max-w-2xl"
320
+ )
321
+ )
322
+ )
323
+
324
+ result = sheet.Sheet(
325
+ sheet_id,
326
+ Div(),
327
+ edit_sheet_content,
328
+ width="80%",
329
+ max_width="800px"
330
+ ).render()
331
+ return result
332
+
333
+ @app.get("/add-provider-sheet", response_class=HTMLResponse)
334
+ async def get_add_provider_sheet():
335
+ edit_sheet_content = sheet.SheetContent(
336
+ sheet.SheetHeader(
337
+ sheet.SheetTitle("Add New Provider"),
338
+ sheet.SheetDescription("Enter details for the new provider.")
339
+ ),
340
+ sheet.SheetBody(
341
+ Div(
342
+ form.Form(
343
+ form.FormField("Provider", "provider", placeholder="Enter provider name", required=True),
344
+ form.FormField("Base URL", "base_url", placeholder="Enter base URL", required=True),
345
+ form.FormField("API Key", "api_key", type="text", placeholder="Enter API key"),
346
+ Div(
347
+ Div("Models", class_="text-lg font-semibold mb-2"),
348
+ Div(id="models-container"),
349
+ button.button(
350
+ "Add Model",
351
+ class_="mt-2",
352
+ hx_post="/add-model",
353
+ hx_target="#models-container",
354
+ hx_swap="beforeend"
355
+ ),
356
+ class_="mb-4"
357
+ ),
358
+ Div(
359
+ checkbox.checkbox("tools", "Enable Tools", name="tools"),
360
+ class_="mb-4"
361
+ ),
362
+ form.FormField("Notes", "notes", placeholder="Enter any additional notes"),
363
+ Div(
364
+ button.button("Submit", variant="primary", type="submit"),
365
+ button.button("Cancel", variant="outline", class_="ml-2"),
366
+ class_="flex justify-end mt-4"
367
+ ),
368
+ hx_post="/submit/new",
369
+ hx_swap="outerHTML",
370
+ hx_target="body",
371
+ class_="space-y-4"
372
+ ),
373
+ class_="container mx-auto p-4 max-w-2xl"
374
+ )
375
+ )
376
+ )
377
+
378
+ result = sheet.Sheet(
379
+ "add-provider-sheet",
380
+ Div(),
381
+ edit_sheet_content,
382
+ width="80%",
383
+ max_width="800px"
384
+ ).render()
385
+ return result
386
+
387
+ def get_row_data(row_id):
388
+ index = int(row_id)
389
+ # print(app.state.config["providers"])
390
+ return app.state.config["providers"][index]
391
+
392
+ def update_row_data(row_id, updated_data):
393
+ print(row_id, updated_data)
394
+ index = int(row_id)
395
+ app.state.config["providers"][index] = updated_data
396
+ with open("./api1.yaml", "w", encoding="utf-8") as f:
397
+ yaml.dump(app.state.config, f)
398
+
399
+ @app.post("/submit/{row_id}", response_class=HTMLResponse)
400
+ async def submit_form(
401
+ row_id: str,
402
+ request: Request,
403
+ provider: str = FastapiForm(...),
404
+ base_url: str = FastapiForm(...),
405
+ api_key: Optional[str] = FastapiForm(None),
406
+ tools: Optional[str] = FastapiForm(None),
407
+ notes: Optional[str] = FastapiForm(None),
408
+ x_api_key: str = Depends(get_api_key)
409
+ ):
410
+ form_data = await request.form()
411
+
412
+ # 收集模型数据
413
+ models = []
414
+ for key, value in form_data.items():
415
+ if key.startswith("model_name_"):
416
+ model_id = key.split("_")[-1]
417
+ enabled = form_data.get(f"model_enabled_{model_id}") == "on"
418
+ rename = form_data.get(f"model_rename_{model_id}")
419
+ if value:
420
+ if rename:
421
+ models.append({value: rename})
422
+ else:
423
+ models.append(value)
424
+
425
+ updated_data = {
426
+ "provider": provider,
427
+ "base_url": base_url,
428
+ "api": api_key,
429
+ "model": models,
430
+ "tools": tools == "on",
431
+ "notes": notes,
432
+ }
433
+
434
+ print("updated_data", updated_data)
435
+
436
+ if row_id == "new":
437
+ # 添加新提供者
438
+ app.state.config["providers"].append(updated_data)
439
+ else:
440
+ # 更新现有提供者
441
+ update_row_data(row_id, updated_data)
442
+
443
+ # 保存更新后的配置
444
+ with open("./api1.yaml", "w", encoding="utf-8") as f:
445
+ yaml.dump(app.state.config, f)
446
+
447
+ return await root()
448
+
449
+ @app.post("/duplicate/{row_id}", response_class=HTMLResponse)
450
+ async def duplicate_row(row_id: str):
451
+ index = int(row_id)
452
+ original_data = app.state.config["providers"][index]
453
+ new_data = original_data.copy()
454
+ new_data["provider"] += "-copy"
455
+ app.state.config["providers"].insert(index + 1, new_data)
456
+
457
+ # 保存更新后的配置
458
+ with open("./api1.yaml", "w", encoding="utf-8") as f:
459
+ yaml.dump(app.state.config, f)
460
+
461
+ return await root()
462
+
463
+ @app.delete("/delete/{row_id}", response_class=HTMLResponse)
464
+ async def delete_row(row_id: str):
465
+ index = int(row_id)
466
+ del app.state.config["providers"][index]
467
+
468
+ # 保存更新后的配置
469
+ with open("./api1.yaml", "w", encoding="utf-8") as f:
470
+ yaml.dump(app.state.config, f)
471
+
472
+ return await root()
473
+
474
+ if __name__ == "__main__":
475
+ import uvicorn
476
+ uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
utils.py CHANGED
@@ -3,26 +3,74 @@ from fastapi import HTTPException
3
  import httpx
4
 
5
  from log_config import logger
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def update_config(config_data):
8
  for index, provider in enumerate(config_data['providers']):
9
- model_dict = {}
10
- for model in provider['model']:
11
- if type(model) == str:
12
- model_dict[model] = model
13
- if type(model) == dict:
14
- model_dict.update({new: old for old, new in model.items()})
15
- provider['model'] = model_dict
16
  if provider.get('project_id'):
17
  provider['base_url'] = 'https://aiplatform.googleapis.com/'
18
  if provider.get('cf_account_id'):
19
  provider['base_url'] = 'https://api.cloudflare.com/'
20
 
21
- if provider.get('api'):
22
- if isinstance(provider.get('api'), str):
23
- provider['api'] = CircularList([provider.get('api')])
24
- if isinstance(provider.get('api'), list):
25
- provider['api'] = CircularList(provider.get('api'))
 
 
26
 
27
  config_data['providers'][index] = provider
28
 
@@ -46,31 +94,24 @@ def update_config(config_data):
46
  api_keys_db[index]['model'] = models
47
 
48
  api_list = [item["api"] for item in api_keys_db]
49
- # logger.info(json.dumps(config_data, indent=4, ensure_ascii=False, default=circular_list_encoder))
50
  return config_data, api_keys_db, api_list
51
 
52
  # 读取YAML配置文件
53
  async def load_config(app=None):
54
- import yaml
55
  try:
56
- # with open('./api.yaml', 'r') as f:
57
- # tokens = yaml.scan(f)
58
- # for token in tokens:
59
- # if isinstance(token, yaml.ScalarToken):
60
- # value = token.value
61
- # # 如果plain为False,表示字符串被引号包裹
62
- # is_quoted = not token.plain
63
- # print(f"值: {value}, 是否被引号包裹: {is_quoted}")
64
-
65
- with open("./api.yaml", "r", encoding="utf-8") as f:
66
- # 判断是否为空文件
67
- conf = yaml.safe_load(f)
68
- # conf = None
69
- if conf:
70
- config, api_keys_db, api_list = update_config(conf)
71
- else:
72
- # logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
73
- config, api_keys_db, api_list = [], [], []
74
  except FileNotFoundError:
75
  logger.error("'api.yaml' not found. Please check the file path.")
76
  config, api_keys_db, api_list = [], [], []
@@ -228,7 +269,8 @@ def get_all_models(config):
228
  unique_models = set()
229
 
230
  for provider in config["providers"]:
231
- for model in provider['model'].keys():
 
232
  if model not in unique_models:
233
  unique_models.add(model)
234
  model_info = {
@@ -260,35 +302,12 @@ def get_all_models(config):
260
  # europe-west1
261
  # europe-west4
262
 
263
- def circular_list_encoder(obj):
264
- if isinstance(obj, CircularList):
265
- return obj.to_dict()
266
- raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
267
-
268
- from collections import deque
269
- class CircularList:
270
- def __init__(self, items):
271
- self.queue = deque(items)
272
-
273
- def next(self):
274
- if not self.queue:
275
- return None
276
- item = self.queue.popleft()
277
- self.queue.append(item)
278
- return item
279
-
280
- def to_dict(self):
281
- return {
282
- 'queue': list(self.queue)
283
- }
284
-
285
-
286
 
287
- c35s = CircularList(["us-east5", "europe-west1"])
288
- c3s = CircularList(["us-east5", "us-central1", "asia-southeast1"])
289
- c3o = CircularList(["us-east5"])
290
- c3h = CircularList(["us-east5", "us-central1", "europe-west1", "europe-west4"])
291
- gem = CircularList(["us-central1", "us-east4", "us-west1", "us-west4", "europe-west1", "europe-west2"])
292
 
293
  class BaseAPI:
294
  def __init__(
 
3
  import httpx
4
 
5
  from log_config import logger
6
+ from collections import defaultdict
7
+
8
+ import asyncio
9
+
10
+ class ThreadSafeCircularList:
11
+ def __init__(self, items):
12
+ self.items = items
13
+ self.index = 0
14
+ self.lock = asyncio.Lock()
15
+
16
+ async def next(self):
17
+ async with self.lock:
18
+ item = self.items[self.index]
19
+ self.index = (self.index + 1) % len(self.items)
20
+ return item
21
+
22
+ def circular_list_encoder(obj):
23
+ if isinstance(obj, ThreadSafeCircularList):
24
+ return obj.to_dict()
25
+ raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
26
+
27
+ provider_api_circular_list = defaultdict(ThreadSafeCircularList)
28
+
29
+ def get_model_dict(provider):
30
+ model_dict = {}
31
+ for model in provider['model']:
32
+ if type(model) == str:
33
+ model_dict[model] = model
34
+ if type(model) == dict:
35
+ model_dict.update({new: old for old, new in model.items()})
36
+ return model_dict
37
+
38
+ def update_initial_model(api_url, api):
39
+ try:
40
+ endpoint = BaseAPI(api_url=api_url)
41
+ endpoint_models_url = endpoint.v1_models
42
+ response = httpx.get(
43
+ endpoint_models_url,
44
+ headers={"Authorization": f"Bearer {api}"},
45
+ )
46
+ models = response.json()
47
+ # print(models)
48
+ models_list = models["data"]
49
+ models_id = [model["id"] for model in models_list]
50
+ set_models = set()
51
+ for model_item in models_id:
52
+ set_models.add(model_item)
53
+ models_id = list(set_models)
54
+ # print(models_id)
55
+ return models_id
56
+ except Exception as e:
57
+ print("error:", e)
58
+ return []
59
 
60
  def update_config(config_data):
61
  for index, provider in enumerate(config_data['providers']):
 
 
 
 
 
 
 
62
  if provider.get('project_id'):
63
  provider['base_url'] = 'https://aiplatform.googleapis.com/'
64
  if provider.get('cf_account_id'):
65
  provider['base_url'] = 'https://api.cloudflare.com/'
66
 
67
+ provider_api_circular_list[provider['provider']] = ThreadSafeCircularList([provider.get('api', None)])
68
+
69
+ if not provider.get("model"):
70
+ provider["model"] = update_initial_model(provider['base_url'], provider['api'])
71
+
72
+ if provider.get("tools") == None:
73
+ provider["tools"] = True
74
 
75
  config_data['providers'][index] = provider
76
 
 
94
  api_keys_db[index]['model'] = models
95
 
96
  api_list = [item["api"] for item in api_keys_db]
97
+ # logger.info(json.dumps(config_data, indent=4, ensure_ascii=False))
98
  return config_data, api_keys_db, api_list
99
 
100
  # 读取YAML配置文件
101
  async def load_config(app=None):
 
102
  try:
103
+ from ruamel.yaml import YAML
104
+ yaml = YAML()
105
+ yaml.preserve_quotes = True
106
+ yaml.indent(mapping=2, sequence=4, offset=2)
107
+ with open('api.yaml', 'r', encoding='utf-8') as file:
108
+ conf = yaml.load(file)
109
+
110
+ if conf:
111
+ config, api_keys_db, api_list = update_config(conf)
112
+ else:
113
+ # logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
114
+ config, api_keys_db, api_list = [], [], []
 
 
 
 
 
 
115
  except FileNotFoundError:
116
  logger.error("'api.yaml' not found. Please check the file path.")
117
  config, api_keys_db, api_list = [], [], []
 
269
  unique_models = set()
270
 
271
  for provider in config["providers"]:
272
+ model_dict = get_model_dict(provider)
273
+ for model in model_dict.keys():
274
  if model not in unique_models:
275
  unique_models.add(model)
276
  model_info = {
 
302
  # europe-west1
303
  # europe-west4
304
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
+ c35s = ThreadSafeCircularList(["us-east5", "europe-west1"])
307
+ c3s = ThreadSafeCircularList(["us-east5", "us-central1", "asia-southeast1"])
308
+ c3o = ThreadSafeCircularList(["us-east5"])
309
+ c3h = ThreadSafeCircularList(["us-east5", "us-central1", "europe-west1", "europe-west4"])
310
+ gem = ThreadSafeCircularList(["us-central1", "us-east4", "us-west1", "us-west4", "europe-west1", "europe-west2"])
311
 
312
  class BaseAPI:
313
  def __init__(