✨ Feature: 1. Add support for experimental frontend.
Browse files2. Add support for configuration files without listing models.
- .gitignore +1 -1
- components/provider_table.py +230 -0
- main.py +24 -15
- models.py +10 -4
- request.py +42 -28
- requirements.txt +2 -1
- test/test_ruamel_yaml.py +44 -0
- test/xue/test_dropdown_sheet.py +75 -0
- test/xue/test_form_uni_api.py +116 -0
- test/xue/test_home.py +476 -0
- utils.py +80 -61
.gitignore
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
api.json
|
2 |
-
|
3 |
.env
|
4 |
__pycache__
|
5 |
.vscode
|
|
|
1 |
api.json
|
2 |
+
*.yaml
|
3 |
.env
|
4 |
__pycache__
|
5 |
.vscode
|
components/provider_table.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from xue import Div, Table, Thead, Tbody, Tr, Th, Td, Button, Input, Script, Head, Style, Span
|
2 |
+
from xue.components.checkbox import checkbox
|
3 |
+
from xue.components.dropdown import dropdown_menu, dropdown_menu_content
|
4 |
+
from xue.components.button import button
|
5 |
+
from xue.components.input import input
|
6 |
+
|
7 |
+
Head.add_default_children([
|
8 |
+
Style("""
|
9 |
+
.data-table-container {
|
10 |
+
width: 100%;
|
11 |
+
overflow-x: auto;
|
12 |
+
border: 1px solid #e2e8f0;
|
13 |
+
border-radius: 0.5rem;
|
14 |
+
overflow-x: visible !important;
|
15 |
+
}
|
16 |
+
.data-table {
|
17 |
+
width: 100%;
|
18 |
+
border-collapse: separate;
|
19 |
+
border-spacing: 0;
|
20 |
+
}
|
21 |
+
.data-table th, .data-table td {
|
22 |
+
padding: 0.75rem 1rem;
|
23 |
+
text-align: left;
|
24 |
+
border-bottom: 1px solid #e2e8f0;
|
25 |
+
}
|
26 |
+
.data-table th {
|
27 |
+
font-weight: 500;
|
28 |
+
font-size: 0.875rem;
|
29 |
+
color: #4b5563;
|
30 |
+
height: 2.5rem;
|
31 |
+
transition: background-color 0.2s;
|
32 |
+
}
|
33 |
+
.data-table thead tr:hover th,
|
34 |
+
.data-table tbody tr:hover {
|
35 |
+
background-color: #f8fafc;
|
36 |
+
}
|
37 |
+
.data-table tbody tr:last-child td {
|
38 |
+
border-bottom: none;
|
39 |
+
}
|
40 |
+
.sortable-header {
|
41 |
+
cursor: pointer;
|
42 |
+
user-select: none;
|
43 |
+
display: inline-flex;
|
44 |
+
align-items: center;
|
45 |
+
padding: 0.25rem 0.5rem;
|
46 |
+
border-radius: 0.25rem;
|
47 |
+
transition: background-color 0.2s;
|
48 |
+
}
|
49 |
+
.sortable-header:hover {
|
50 |
+
background-color: #e5e7eb;
|
51 |
+
}
|
52 |
+
.sort-icon {
|
53 |
+
display: inline-block;
|
54 |
+
width: 1rem;
|
55 |
+
height: 1rem;
|
56 |
+
margin-left: 0.25rem;
|
57 |
+
transition: transform 0.2s;
|
58 |
+
opacity: 0;
|
59 |
+
}
|
60 |
+
.sortable-header:hover .sort-icon,
|
61 |
+
.sort-asc .sort-icon,
|
62 |
+
.sort-desc .sort-icon {
|
63 |
+
opacity: 1;
|
64 |
+
}
|
65 |
+
.sort-asc .sort-icon {
|
66 |
+
transform: rotate(180deg);
|
67 |
+
}
|
68 |
+
.table-header {
|
69 |
+
display: flex;
|
70 |
+
justify-content: space-between;
|
71 |
+
align-items: center;
|
72 |
+
margin-bottom: 1rem;
|
73 |
+
}
|
74 |
+
.table-footer {
|
75 |
+
display: flex;
|
76 |
+
justify-content: space-between;
|
77 |
+
align-items: center;
|
78 |
+
margin-top: 1rem;
|
79 |
+
}
|
80 |
+
.pagination {
|
81 |
+
display: flex;
|
82 |
+
gap: 0.5rem;
|
83 |
+
}
|
84 |
+
@media (prefers-color-scheme: dark) {
|
85 |
+
.data-table-container {
|
86 |
+
border-color: #4b5563;
|
87 |
+
}
|
88 |
+
.data-table th, .data-table td {
|
89 |
+
border-color: #4b5563;
|
90 |
+
}
|
91 |
+
.data-table th {
|
92 |
+
color: #d1d5db;
|
93 |
+
}
|
94 |
+
.data-table thead tr:hover th,
|
95 |
+
.data-table tbody tr:hover {
|
96 |
+
background-color: #1f2937;
|
97 |
+
}
|
98 |
+
.sortable-header:hover {
|
99 |
+
background-color: #374151;
|
100 |
+
}
|
101 |
+
}
|
102 |
+
""", id="data-table-style"),
|
103 |
+
Script("""
|
104 |
+
function toggleAllRows(checked) {
|
105 |
+
const checkboxes = document.querySelectorAll('.row-checkbox');
|
106 |
+
checkboxes.forEach(cb => cb.checked = checked);
|
107 |
+
updateSelectedCount();
|
108 |
+
}
|
109 |
+
|
110 |
+
function updateSelectedCount() {
|
111 |
+
const selectedCount = document.querySelectorAll('.row-checkbox:checked').length;
|
112 |
+
const totalCount = document.querySelectorAll('.row-checkbox').length;
|
113 |
+
document.getElementById('selected-count').textContent = `${selectedCount} of ${totalCount} row(s) selected.`;
|
114 |
+
}
|
115 |
+
|
116 |
+
function sortTable(columnIndex, accessor) {
|
117 |
+
const table = document.querySelector('.data-table');
|
118 |
+
const header = table.querySelector(`th[data-accessor="${accessor}"]`);
|
119 |
+
const isAscending = !header.classList.contains('sort-asc');
|
120 |
+
|
121 |
+
// Update sort direction
|
122 |
+
table.querySelectorAll('th').forEach(th => th.classList.remove('sort-asc', 'sort-desc'));
|
123 |
+
header.classList.add(isAscending ? 'sort-asc' : 'sort-desc');
|
124 |
+
|
125 |
+
// Sort the table
|
126 |
+
const rows = Array.from(table.querySelectorAll('tbody tr'));
|
127 |
+
rows.sort((a, b) => {
|
128 |
+
const aValue = a.querySelector(`td[data-accessor="${accessor}"]`).textContent;
|
129 |
+
const bValue = b.querySelector(`td[data-accessor="${accessor}"]`).textContent;
|
130 |
+
return isAscending ? aValue.localeCompare(bValue) : bValue.localeCompare(aValue);
|
131 |
+
});
|
132 |
+
|
133 |
+
// Update the table
|
134 |
+
const tbody = table.querySelector('tbody');
|
135 |
+
rows.forEach(row => tbody.appendChild(row));
|
136 |
+
}
|
137 |
+
|
138 |
+
document.addEventListener('change', function(event) {
|
139 |
+
if (event.target.classList.contains('row-checkbox')) {
|
140 |
+
updateSelectedCount();
|
141 |
+
}
|
142 |
+
});
|
143 |
+
""", id="data-table-script"),
|
144 |
+
])
|
145 |
+
|
146 |
+
def data_table(columns, data, id, with_filter=True):
|
147 |
+
return Div(
|
148 |
+
Div(
|
149 |
+
input(type="text", placeholder="Filter...", id=f"{id}-filter", class_="mr-auto"),
|
150 |
+
Div(
|
151 |
+
button(
|
152 |
+
"Add Provider",
|
153 |
+
variant="secondary",
|
154 |
+
hx_get="/add-provider-sheet",
|
155 |
+
hx_target="#sheet-container",
|
156 |
+
hx_swap="innerHTML",
|
157 |
+
class_="h-[2.625rem]"
|
158 |
+
),
|
159 |
+
dropdown_menu("Columns"),
|
160 |
+
),
|
161 |
+
class_="table-header flex items-center"
|
162 |
+
) if with_filter else None,
|
163 |
+
Div(
|
164 |
+
Div(
|
165 |
+
Table(
|
166 |
+
Thead(
|
167 |
+
Tr(
|
168 |
+
Th(checkbox("select-all", "", onclick="toggleAllRows(this.checked)")),
|
169 |
+
*[Th(
|
170 |
+
Div(
|
171 |
+
col['label'],
|
172 |
+
Span("▼", class_="sort-icon"),
|
173 |
+
class_="sortable-header" if col.get('sortable', False) else "",
|
174 |
+
onclick=f"sortTable({i}, '{col['value']}')" if col.get('sortable', False) else None
|
175 |
+
),
|
176 |
+
data_accessor=col['value']
|
177 |
+
) for i, col in enumerate(columns)],
|
178 |
+
Th("Actions") # 新增的操作列
|
179 |
+
)
|
180 |
+
),
|
181 |
+
Tbody(
|
182 |
+
*[Tr(
|
183 |
+
Td(checkbox(f"row-{i}", "", class_="row-checkbox")),
|
184 |
+
*[Td(row[col['value']], data_accessor=col['value']) for col in columns],
|
185 |
+
Td(row_actions_menu(i)), # 使用行索引作为 row_id
|
186 |
+
id=f"row-{i}"
|
187 |
+
) for i, row in enumerate(data)]
|
188 |
+
),
|
189 |
+
class_="data-table"
|
190 |
+
),
|
191 |
+
class_="data-table-container"
|
192 |
+
),
|
193 |
+
Div(
|
194 |
+
Div(id="selected-count", class_="text-sm text-gray-500"),
|
195 |
+
Div(
|
196 |
+
button("Previous", variant="outline", class_="mr-2"),
|
197 |
+
button("Next", variant="outline"),
|
198 |
+
class_="pagination"
|
199 |
+
),
|
200 |
+
class_="table-footer"
|
201 |
+
),
|
202 |
+
id=id
|
203 |
+
),
|
204 |
+
)
|
205 |
+
|
206 |
+
def get_column_visibility_menu(id, columns):
|
207 |
+
return dropdown_menu_content(id, [
|
208 |
+
{"label": col['label'], "value": col['value']}
|
209 |
+
for col in columns if col.get('can_hide', True)
|
210 |
+
])
|
211 |
+
|
212 |
+
def row_actions_menu(row_id):
|
213 |
+
return dropdown_menu("⋮", id=f"row-actions-menu-{row_id}", hx_get=f"/dropdown-menu/dropdown-menu-⋮/{row_id}")
|
214 |
+
|
215 |
+
def get_row_actions_menu(row_id):
|
216 |
+
return dropdown_menu_content(f"row-actions-{row_id}", [
|
217 |
+
{"label": "Edit", "icon": "pencil"},
|
218 |
+
{"label": "Duplicate", "icon": "copy"},
|
219 |
+
{"label": "Delete", "icon": "trash"},
|
220 |
+
"separator",
|
221 |
+
{"label": "More...", "icon": "more-horizontal"},
|
222 |
+
])
|
223 |
+
|
224 |
+
def render_row(row_data, row_id, columns):
|
225 |
+
return Tr(
|
226 |
+
Td(checkbox(f"row-{row_id}", "", class_="row-checkbox")),
|
227 |
+
*[Td(row_data[col['value']], data_accessor=col['value']) for col in columns],
|
228 |
+
Td(row_actions_menu(row_id)),
|
229 |
+
id=f"row-{row_id}"
|
230 |
+
).render()
|
main.py
CHANGED
@@ -18,7 +18,7 @@ from fastapi.exceptions import RequestValidationError
|
|
18 |
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
|
19 |
from request import get_payload
|
20 |
from response import fetch_response, fetch_response_stream
|
21 |
-
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
22 |
|
23 |
from collections import defaultdict
|
24 |
from typing import List, Dict, Union
|
@@ -492,20 +492,21 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
492 |
else:
|
493 |
engine = "gpt"
|
494 |
|
495 |
-
|
496 |
-
|
497 |
-
and "
|
|
|
498 |
and parsed_url.netloc != 'api.cloudflare.com' \
|
499 |
and parsed_url.netloc != 'api.cohere.com':
|
500 |
engine = "openrouter"
|
501 |
|
502 |
-
if "claude" in
|
503 |
engine = "vertex-claude"
|
504 |
|
505 |
-
if "gemini" in
|
506 |
engine = "vertex-gemini"
|
507 |
|
508 |
-
if "o1-preview" in
|
509 |
engine = "o1"
|
510 |
request.stream = False
|
511 |
|
@@ -536,7 +537,7 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
536 |
current_info = request_info.get()
|
537 |
try:
|
538 |
if request.stream:
|
539 |
-
model =
|
540 |
generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
|
541 |
wrapped_generator, first_response_time = await error_handling_wrapper(generator)
|
542 |
response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream")
|
@@ -603,7 +604,8 @@ class ModelRequestHandler:
|
|
603 |
if model == "all":
|
604 |
# 如果模型名为 *,则返回所有模型
|
605 |
for provider in config["providers"]:
|
606 |
-
|
|
|
607 |
provider_rules.append(provider["provider"] + "/" + model)
|
608 |
break
|
609 |
if "/" in model:
|
@@ -611,15 +613,17 @@ class ModelRequestHandler:
|
|
611 |
model = model[1:-1]
|
612 |
# 处理带斜杠的模型名
|
613 |
for provider in config['providers']:
|
614 |
-
|
|
|
615 |
provider_rules.append(provider['provider'] + "/" + model)
|
616 |
else:
|
617 |
provider_name = model.split("/")[0]
|
618 |
model_name_split = "/".join(model.split("/")[1:])
|
619 |
models_list = []
|
620 |
for provider in config['providers']:
|
|
|
621 |
if provider['provider'] == provider_name:
|
622 |
-
models_list.extend(list(
|
623 |
# print("models_list", models_list)
|
624 |
# print("model_name", model_name)
|
625 |
# print("model_name_split", model_name_split)
|
@@ -632,7 +636,8 @@ class ModelRequestHandler:
|
|
632 |
provider_rules.append(provider_name)
|
633 |
else:
|
634 |
for provider in config['providers']:
|
635 |
-
|
|
|
636 |
provider_rules.append(provider['provider'] + "/" + model)
|
637 |
|
638 |
provider_list = []
|
@@ -642,10 +647,13 @@ class ModelRequestHandler:
|
|
642 |
# print("provider", provider, provider['provider'] == item, item)
|
643 |
if "/" in item:
|
644 |
if provider['provider'] == item.split("/")[0]:
|
645 |
-
|
|
|
646 |
provider_list.append(provider)
|
|
|
647 |
elif provider['provider'] == item:
|
648 |
-
|
|
|
649 |
provider_list.append(provider)
|
650 |
else:
|
651 |
pass
|
@@ -655,7 +663,8 @@ class ModelRequestHandler:
|
|
655 |
# if item.split("/")[1] == model_name:
|
656 |
# provider_list.append(provider)
|
657 |
# else:
|
658 |
-
#
|
|
|
659 |
# provider_list.append(provider)
|
660 |
if is_debug:
|
661 |
for provider in provider_list:
|
|
|
18 |
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
|
19 |
from request import get_payload
|
20 |
from response import fetch_response, fetch_response_stream
|
21 |
+
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict
|
22 |
|
23 |
from collections import defaultdict
|
24 |
from typing import List, Dict, Union
|
|
|
492 |
else:
|
493 |
engine = "gpt"
|
494 |
|
495 |
+
model_dict = get_model_dict(provider)
|
496 |
+
if "claude" not in model_dict[request.model] \
|
497 |
+
and "gpt" not in model_dict[request.model] \
|
498 |
+
and "gemini" not in model_dict[request.model] \
|
499 |
and parsed_url.netloc != 'api.cloudflare.com' \
|
500 |
and parsed_url.netloc != 'api.cohere.com':
|
501 |
engine = "openrouter"
|
502 |
|
503 |
+
if "claude" in model_dict[request.model] and engine == "vertex":
|
504 |
engine = "vertex-claude"
|
505 |
|
506 |
+
if "gemini" in model_dict[request.model] and engine == "vertex":
|
507 |
engine = "vertex-gemini"
|
508 |
|
509 |
+
if "o1-preview" in model_dict[request.model] or "o1-mini" in model_dict[request.model]:
|
510 |
engine = "o1"
|
511 |
request.stream = False
|
512 |
|
|
|
537 |
current_info = request_info.get()
|
538 |
try:
|
539 |
if request.stream:
|
540 |
+
model = model_dict[request.model]
|
541 |
generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
|
542 |
wrapped_generator, first_response_time = await error_handling_wrapper(generator)
|
543 |
response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream")
|
|
|
604 |
if model == "all":
|
605 |
# 如果模型名为 *,则返回所有模型
|
606 |
for provider in config["providers"]:
|
607 |
+
model_dict = get_model_dict(provider)
|
608 |
+
for model in model_dict.keys():
|
609 |
provider_rules.append(provider["provider"] + "/" + model)
|
610 |
break
|
611 |
if "/" in model:
|
|
|
613 |
model = model[1:-1]
|
614 |
# 处理带斜杠的模型名
|
615 |
for provider in config['providers']:
|
616 |
+
model_dict = get_model_dict(provider)
|
617 |
+
if model in model_dict.keys():
|
618 |
provider_rules.append(provider['provider'] + "/" + model)
|
619 |
else:
|
620 |
provider_name = model.split("/")[0]
|
621 |
model_name_split = "/".join(model.split("/")[1:])
|
622 |
models_list = []
|
623 |
for provider in config['providers']:
|
624 |
+
model_dict = get_model_dict(provider)
|
625 |
if provider['provider'] == provider_name:
|
626 |
+
models_list.extend(list(model_dict.keys()))
|
627 |
# print("models_list", models_list)
|
628 |
# print("model_name", model_name)
|
629 |
# print("model_name_split", model_name_split)
|
|
|
636 |
provider_rules.append(provider_name)
|
637 |
else:
|
638 |
for provider in config['providers']:
|
639 |
+
model_dict = get_model_dict(provider)
|
640 |
+
if model in model_dict.keys():
|
641 |
provider_rules.append(provider['provider'] + "/" + model)
|
642 |
|
643 |
provider_list = []
|
|
|
647 |
# print("provider", provider, provider['provider'] == item, item)
|
648 |
if "/" in item:
|
649 |
if provider['provider'] == item.split("/")[0]:
|
650 |
+
model_dict = get_model_dict(provider)
|
651 |
+
if model_name in model_dict.keys() and "/".join(item.split("/")[1:]) == model_name:
|
652 |
provider_list.append(provider)
|
653 |
+
# 如果 item 不包含 /,则直接匹配 provider,说明整个渠道所有模型都能用
|
654 |
elif provider['provider'] == item:
|
655 |
+
model_dict = get_model_dict(provider)
|
656 |
+
if model_name in model_dict.keys():
|
657 |
provider_list.append(provider)
|
658 |
else:
|
659 |
pass
|
|
|
663 |
# if item.split("/")[1] == model_name:
|
664 |
# provider_list.append(provider)
|
665 |
# else:
|
666 |
+
# model_dict = get_model_dict(provider)
|
667 |
+
# if model_name in model_dict.keys():
|
668 |
# provider_list.append(provider)
|
669 |
if is_debug:
|
670 |
for provider in provider_list:
|
models.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from io import IOBase
|
2 |
-
from pydantic import BaseModel, Field, model_validator
|
3 |
from typing import List, Dict, Optional, Union, Tuple, Literal, Any
|
4 |
from log_config import logger
|
5 |
|
@@ -61,10 +61,16 @@ class ToolChoice(BaseModel):
|
|
61 |
class BaseRequest(BaseModel):
|
62 |
request_type: Optional[Literal["chat", "image", "audio", "moderation"]] = Field(default=None, exclude=True)
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
67 |
|
|
|
68 |
class ResponseFormat(BaseModel):
|
69 |
type: Literal["text", "json_object", "json_schema"]
|
70 |
json_schema: Optional[JsonSchema] = None
|
|
|
1 |
from io import IOBase
|
2 |
+
from pydantic import BaseModel, Field, model_validator, ConfigDict
|
3 |
from typing import List, Dict, Optional, Union, Tuple, Literal, Any
|
4 |
from log_config import logger
|
5 |
|
|
|
61 |
class BaseRequest(BaseModel):
|
62 |
request_type: Optional[Literal["chat", "image", "audio", "moderation"]] = Field(default=None, exclude=True)
|
63 |
|
64 |
+
def create_json_schema_class():
|
65 |
+
class JsonSchema(BaseModel):
|
66 |
+
name: str
|
67 |
+
|
68 |
+
model_config = ConfigDict(protected_namespaces=())
|
69 |
+
|
70 |
+
JsonSchema.__annotations__['schema'] = Dict[str, Any]
|
71 |
+
return JsonSchema
|
72 |
|
73 |
+
JsonSchema = create_json_schema_class()
|
74 |
class ResponseFormat(BaseModel):
|
75 |
type: Literal["text", "json_object", "json_schema"]
|
76 |
json_schema: Optional[JsonSchema] = None
|
request.py
CHANGED
@@ -6,7 +6,7 @@ import base64
|
|
6 |
import urllib.parse
|
7 |
|
8 |
from models import RequestModel
|
9 |
-
from utils import c35s, c3s, c3o, c3h, gem, BaseAPI
|
10 |
|
11 |
import imghdr
|
12 |
|
@@ -120,13 +120,14 @@ async def get_gemini_payload(request, engine, provider):
|
|
120 |
headers = {
|
121 |
'Content-Type': 'application/json'
|
122 |
}
|
123 |
-
|
|
|
124 |
gemini_stream = "streamGenerateContent"
|
125 |
url = provider['base_url']
|
126 |
if url.endswith("v1beta"):
|
127 |
-
url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=provider['
|
128 |
if url.endswith("v1"):
|
129 |
-
url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=provider['
|
130 |
|
131 |
messages = []
|
132 |
systemInstruction = None
|
@@ -312,7 +313,8 @@ async def get_vertex_gemini_payload(request, engine, provider):
|
|
312 |
project_id = provider.get("project_id")
|
313 |
|
314 |
gemini_stream = "streamGenerateContent"
|
315 |
-
|
|
|
316 |
location = gem
|
317 |
url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL_ID=model, stream=gemini_stream)
|
318 |
|
@@ -449,7 +451,8 @@ async def get_vertex_claude_payload(request, engine, provider):
|
|
449 |
if provider.get("project_id"):
|
450 |
project_id = provider.get("project_id")
|
451 |
|
452 |
-
|
|
|
453 |
if "claude-3-5-sonnet" in model:
|
454 |
location = c35s
|
455 |
elif "claude-3-opus" in model:
|
@@ -460,7 +463,7 @@ async def get_vertex_claude_payload(request, engine, provider):
|
|
460 |
location = c3h
|
461 |
|
462 |
claude_stream = "streamRawPredict"
|
463 |
-
url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL=model, stream=claude_stream)
|
464 |
|
465 |
messages = []
|
466 |
system_prompt = None
|
@@ -534,7 +537,8 @@ async def get_vertex_claude_payload(request, engine, provider):
|
|
534 |
else:
|
535 |
message_index = message_index + 1
|
536 |
|
537 |
-
|
|
|
538 |
payload = {
|
539 |
"anthropic_version": "vertex-2023-10-16",
|
540 |
"messages": messages,
|
@@ -593,7 +597,7 @@ async def get_gpt_payload(request, engine, provider):
|
|
593 |
'Content-Type': 'application/json',
|
594 |
}
|
595 |
if provider.get("api"):
|
596 |
-
headers['Authorization'] = f"Bearer {provider['
|
597 |
url = provider['base_url']
|
598 |
|
599 |
messages = []
|
@@ -633,7 +637,8 @@ async def get_gpt_payload(request, engine, provider):
|
|
633 |
else:
|
634 |
messages.append({"role": msg.role, "content": content})
|
635 |
|
636 |
-
|
|
|
637 |
payload = {
|
638 |
"model": model,
|
639 |
"messages": messages,
|
@@ -659,7 +664,7 @@ async def get_openrouter_payload(request, engine, provider):
|
|
659 |
'Content-Type': 'application/json'
|
660 |
}
|
661 |
if provider.get("api"):
|
662 |
-
headers['Authorization'] = f"Bearer {provider['
|
663 |
|
664 |
url = provider['base_url']
|
665 |
|
@@ -691,7 +696,8 @@ async def get_openrouter_payload(request, engine, provider):
|
|
691 |
else:
|
692 |
messages.append({"role": msg.role, "content": content})
|
693 |
|
694 |
-
|
|
|
695 |
payload = {
|
696 |
"model": model,
|
697 |
"messages": messages,
|
@@ -725,7 +731,7 @@ async def get_cohere_payload(request, engine, provider):
|
|
725 |
'Content-Type': 'application/json'
|
726 |
}
|
727 |
if provider.get("api"):
|
728 |
-
headers['Authorization'] = f"Bearer {provider['
|
729 |
|
730 |
url = provider['base_url']
|
731 |
|
@@ -753,7 +759,8 @@ async def get_cohere_payload(request, engine, provider):
|
|
753 |
else:
|
754 |
messages.append({"role": role_map[msg.role], "message": content})
|
755 |
|
756 |
-
|
|
|
757 |
chat_history = messages[:-1]
|
758 |
query = messages[-1].get("message")
|
759 |
payload = {
|
@@ -792,9 +799,10 @@ async def get_cloudflare_payload(request, engine, provider):
|
|
792 |
'Content-Type': 'application/json'
|
793 |
}
|
794 |
if provider.get("api"):
|
795 |
-
headers['Authorization'] = f"Bearer {provider['
|
796 |
|
797 |
-
|
|
|
798 |
url = "https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/ai/run/{cf_model_id}".format(cf_account_id=provider['cf_account_id'], cf_model_id=model)
|
799 |
|
800 |
msg = request.messages[-1]
|
@@ -808,7 +816,7 @@ async def get_cloudflare_payload(request, engine, provider):
|
|
808 |
content = msg.content
|
809 |
name = msg.name
|
810 |
|
811 |
-
model =
|
812 |
payload = {
|
813 |
"prompt": content,
|
814 |
}
|
@@ -841,7 +849,7 @@ async def get_o1_payload(request, engine, provider):
|
|
841 |
'Content-Type': 'application/json'
|
842 |
}
|
843 |
if provider.get("api"):
|
844 |
-
headers['Authorization'] = f"Bearer {provider['
|
845 |
|
846 |
url = provider['base_url']
|
847 |
|
@@ -863,7 +871,8 @@ async def get_o1_payload(request, engine, provider):
|
|
863 |
elif msg.role != "system":
|
864 |
messages.append({"role": msg.role, "content": content})
|
865 |
|
866 |
-
|
|
|
867 |
payload = {
|
868 |
"model": model,
|
869 |
"messages": messages,
|
@@ -912,10 +921,11 @@ async def gpt2claude_tools_json(json_dict):
|
|
912 |
return json_dict
|
913 |
|
914 |
async def get_claude_payload(request, engine, provider):
|
915 |
-
|
|
|
916 |
headers = {
|
917 |
"content-type": "application/json",
|
918 |
-
"x-api-key": f"{provider['
|
919 |
"anthropic-version": "2023-06-01",
|
920 |
"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in model else "tools-2024-05-16",
|
921 |
}
|
@@ -993,7 +1003,8 @@ async def get_claude_payload(request, engine, provider):
|
|
993 |
else:
|
994 |
message_index = message_index + 1
|
995 |
|
996 |
-
|
|
|
997 |
payload = {
|
998 |
"model": model,
|
999 |
"messages": messages,
|
@@ -1051,12 +1062,13 @@ async def get_claude_payload(request, engine, provider):
|
|
1051 |
return url, headers, payload
|
1052 |
|
1053 |
async def get_dalle_payload(request, engine, provider):
|
1054 |
-
|
|
|
1055 |
headers = {
|
1056 |
"Content-Type": "application/json",
|
1057 |
}
|
1058 |
if provider.get("api"):
|
1059 |
-
headers['Authorization'] = f"Bearer {provider['
|
1060 |
url = provider['base_url']
|
1061 |
url = BaseAPI(url).image_url
|
1062 |
|
@@ -1070,12 +1082,13 @@ async def get_dalle_payload(request, engine, provider):
|
|
1070 |
return url, headers, payload
|
1071 |
|
1072 |
async def get_whisper_payload(request, engine, provider):
|
1073 |
-
|
|
|
1074 |
headers = {
|
1075 |
# "Content-Type": "multipart/form-data",
|
1076 |
}
|
1077 |
if provider.get("api"):
|
1078 |
-
headers['Authorization'] = f"Bearer {provider['
|
1079 |
url = provider['base_url']
|
1080 |
url = BaseAPI(url).audio_transcriptions
|
1081 |
|
@@ -1096,12 +1109,13 @@ async def get_whisper_payload(request, engine, provider):
|
|
1096 |
return url, headers, payload
|
1097 |
|
1098 |
async def get_moderation_payload(request, engine, provider):
|
1099 |
-
|
|
|
1100 |
headers = {
|
1101 |
"Content-Type": "application/json",
|
1102 |
}
|
1103 |
if provider.get("api"):
|
1104 |
-
headers['Authorization'] = f"Bearer {provider['
|
1105 |
url = provider['base_url']
|
1106 |
url = BaseAPI(url).moderations
|
1107 |
|
|
|
6 |
import urllib.parse
|
7 |
|
8 |
from models import RequestModel
|
9 |
+
from utils import c35s, c3s, c3o, c3h, gem, BaseAPI, get_model_dict, provider_api_circular_list
|
10 |
|
11 |
import imghdr
|
12 |
|
|
|
120 |
headers = {
|
121 |
'Content-Type': 'application/json'
|
122 |
}
|
123 |
+
model_dict = get_model_dict(provider)
|
124 |
+
model = model_dict[request.model]
|
125 |
gemini_stream = "streamGenerateContent"
|
126 |
url = provider['base_url']
|
127 |
if url.endswith("v1beta"):
|
128 |
+
url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
|
129 |
if url.endswith("v1"):
|
130 |
+
url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
|
131 |
|
132 |
messages = []
|
133 |
systemInstruction = None
|
|
|
313 |
project_id = provider.get("project_id")
|
314 |
|
315 |
gemini_stream = "streamGenerateContent"
|
316 |
+
model_dict = get_model_dict(provider)
|
317 |
+
model = model_dict[request.model]
|
318 |
location = gem
|
319 |
url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL_ID=model, stream=gemini_stream)
|
320 |
|
|
|
451 |
if provider.get("project_id"):
|
452 |
project_id = provider.get("project_id")
|
453 |
|
454 |
+
model_dict = get_model_dict(provider)
|
455 |
+
model = model_dict[request.model]
|
456 |
if "claude-3-5-sonnet" in model:
|
457 |
location = c35s
|
458 |
elif "claude-3-opus" in model:
|
|
|
463 |
location = c3h
|
464 |
|
465 |
claude_stream = "streamRawPredict"
|
466 |
+
url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(LOCATION=await location.next(), PROJECT_ID=project_id, MODEL=model, stream=claude_stream)
|
467 |
|
468 |
messages = []
|
469 |
system_prompt = None
|
|
|
537 |
else:
|
538 |
message_index = message_index + 1
|
539 |
|
540 |
+
model_dict = get_model_dict(provider)
|
541 |
+
model = model_dict[request.model]
|
542 |
payload = {
|
543 |
"anthropic_version": "vertex-2023-10-16",
|
544 |
"messages": messages,
|
|
|
597 |
'Content-Type': 'application/json',
|
598 |
}
|
599 |
if provider.get("api"):
|
600 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
601 |
url = provider['base_url']
|
602 |
|
603 |
messages = []
|
|
|
637 |
else:
|
638 |
messages.append({"role": msg.role, "content": content})
|
639 |
|
640 |
+
model_dict = get_model_dict(provider)
|
641 |
+
model = model_dict[request.model]
|
642 |
payload = {
|
643 |
"model": model,
|
644 |
"messages": messages,
|
|
|
664 |
'Content-Type': 'application/json'
|
665 |
}
|
666 |
if provider.get("api"):
|
667 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
668 |
|
669 |
url = provider['base_url']
|
670 |
|
|
|
696 |
else:
|
697 |
messages.append({"role": msg.role, "content": content})
|
698 |
|
699 |
+
model_dict = get_model_dict(provider)
|
700 |
+
model = model_dict[request.model]
|
701 |
payload = {
|
702 |
"model": model,
|
703 |
"messages": messages,
|
|
|
731 |
'Content-Type': 'application/json'
|
732 |
}
|
733 |
if provider.get("api"):
|
734 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
735 |
|
736 |
url = provider['base_url']
|
737 |
|
|
|
759 |
else:
|
760 |
messages.append({"role": role_map[msg.role], "message": content})
|
761 |
|
762 |
+
model_dict = get_model_dict(provider)
|
763 |
+
model = model_dict[request.model]
|
764 |
chat_history = messages[:-1]
|
765 |
query = messages[-1].get("message")
|
766 |
payload = {
|
|
|
799 |
'Content-Type': 'application/json'
|
800 |
}
|
801 |
if provider.get("api"):
|
802 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
803 |
|
804 |
+
model_dict = get_model_dict(provider)
|
805 |
+
model = model_dict[request.model]
|
806 |
url = "https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/ai/run/{cf_model_id}".format(cf_account_id=provider['cf_account_id'], cf_model_id=model)
|
807 |
|
808 |
msg = request.messages[-1]
|
|
|
816 |
content = msg.content
|
817 |
name = msg.name
|
818 |
|
819 |
+
model = model_dict[request.model]
|
820 |
payload = {
|
821 |
"prompt": content,
|
822 |
}
|
|
|
849 |
'Content-Type': 'application/json'
|
850 |
}
|
851 |
if provider.get("api"):
|
852 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
853 |
|
854 |
url = provider['base_url']
|
855 |
|
|
|
871 |
elif msg.role != "system":
|
872 |
messages.append({"role": msg.role, "content": content})
|
873 |
|
874 |
+
model_dict = get_model_dict(provider)
|
875 |
+
model = model_dict[request.model]
|
876 |
payload = {
|
877 |
"model": model,
|
878 |
"messages": messages,
|
|
|
921 |
return json_dict
|
922 |
|
923 |
async def get_claude_payload(request, engine, provider):
|
924 |
+
model_dict = get_model_dict(provider)
|
925 |
+
model = model_dict[request.model]
|
926 |
headers = {
|
927 |
"content-type": "application/json",
|
928 |
+
"x-api-key": f"{await provider_api_circular_list[provider['provider']].next()}",
|
929 |
"anthropic-version": "2023-06-01",
|
930 |
"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in model else "tools-2024-05-16",
|
931 |
}
|
|
|
1003 |
else:
|
1004 |
message_index = message_index + 1
|
1005 |
|
1006 |
+
model_dict = get_model_dict(provider)
|
1007 |
+
model = model_dict[request.model]
|
1008 |
payload = {
|
1009 |
"model": model,
|
1010 |
"messages": messages,
|
|
|
1062 |
return url, headers, payload
|
1063 |
|
1064 |
async def get_dalle_payload(request, engine, provider):
|
1065 |
+
model_dict = get_model_dict(provider)
|
1066 |
+
model = model_dict[request.model]
|
1067 |
headers = {
|
1068 |
"Content-Type": "application/json",
|
1069 |
}
|
1070 |
if provider.get("api"):
|
1071 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
1072 |
url = provider['base_url']
|
1073 |
url = BaseAPI(url).image_url
|
1074 |
|
|
|
1082 |
return url, headers, payload
|
1083 |
|
1084 |
async def get_whisper_payload(request, engine, provider):
|
1085 |
+
model_dict = get_model_dict(provider)
|
1086 |
+
model = model_dict[request.model]
|
1087 |
headers = {
|
1088 |
# "Content-Type": "multipart/form-data",
|
1089 |
}
|
1090 |
if provider.get("api"):
|
1091 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
1092 |
url = provider['base_url']
|
1093 |
url = BaseAPI(url).audio_transcriptions
|
1094 |
|
|
|
1109 |
return url, headers, payload
|
1110 |
|
1111 |
async def get_moderation_payload(request, engine, provider):
|
1112 |
+
model_dict = get_model_dict(provider)
|
1113 |
+
model = model_dict[request.model]
|
1114 |
headers = {
|
1115 |
"Content-Type": "application/json",
|
1116 |
}
|
1117 |
if provider.get("api"):
|
1118 |
+
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
|
1119 |
url = provider['base_url']
|
1120 |
url = BaseAPI(url).moderations
|
1121 |
|
requirements.txt
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
|
2 |
pytest
|
3 |
uvicorn
|
4 |
fastapi
|
@@ -7,6 +7,7 @@ greenlet
|
|
7 |
aiosqlite
|
8 |
sqlalchemy
|
9 |
watchfiles
|
|
|
10 |
httpx[http2]
|
11 |
cryptography
|
12 |
python-multipart
|
|
|
1 |
+
xue
|
2 |
pytest
|
3 |
uvicorn
|
4 |
fastapi
|
|
|
7 |
aiosqlite
|
8 |
sqlalchemy
|
9 |
watchfiles
|
10 |
+
ruamel.yaml
|
11 |
httpx[http2]
|
12 |
cryptography
|
13 |
python-multipart
|
test/test_ruamel_yaml.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ruamel.yaml import YAML
|
2 |
+
|
3 |
+
# 假设我们有以下 YAML 内容
|
4 |
+
yaml_content = """
|
5 |
+
# 这是顶级注释
|
6 |
+
key1: value1 # 行尾注释
|
7 |
+
key2: value2
|
8 |
+
|
9 |
+
# 这是嵌套结构的注释
|
10 |
+
nested:
|
11 |
+
subkey1: subvalue1
|
12 |
+
subkey2: subvalue2 # 嵌套的行尾注释
|
13 |
+
|
14 |
+
# 列表的注释
|
15 |
+
list_key:
|
16 |
+
- item1
|
17 |
+
- item2 # 列表项的注释
|
18 |
+
"""
|
19 |
+
|
20 |
+
# 创建 YAML 对象
|
21 |
+
yaml = YAML()
|
22 |
+
yaml.preserve_quotes = True
|
23 |
+
yaml.indent(mapping=2, sequence=4, offset=2)
|
24 |
+
|
25 |
+
with open('api.yaml', 'r', encoding='utf-8') as file:
|
26 |
+
data = yaml.load(file)
|
27 |
+
|
28 |
+
# data = yaml.load(yaml_content)
|
29 |
+
# 加载 YAML 数据
|
30 |
+
print(data)
|
31 |
+
|
32 |
+
# # 修改数据
|
33 |
+
# data['key1'] = 'new_value1'
|
34 |
+
# data['nested']['subkey1'] = 'new_subvalue1'
|
35 |
+
# data['list_key'].append('new_item')
|
36 |
+
|
37 |
+
# 将修改后的数据写回文件(这里我们使用 StringIO 来模拟文件操作)
|
38 |
+
# from io import StringIO
|
39 |
+
# output = StringIO()
|
40 |
+
# yaml.dump(data, output)
|
41 |
+
# print(output.getvalue())
|
42 |
+
|
43 |
+
with open('formatted.yaml', 'w', encoding='utf-8') as file:
|
44 |
+
yaml.dump(data, file)
|
test/xue/test_dropdown_sheet.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
+
from fastapi.responses import HTMLResponse
|
3 |
+
from xue import HTML, Head, Body, Div, xue_initialize, Script
|
4 |
+
from xue.components import dropdown, sheet, button, form, input
|
5 |
+
|
6 |
+
xue_initialize(tailwind=True)
|
7 |
+
app = FastAPI()
|
8 |
+
|
9 |
+
@app.get("/", response_class=HTMLResponse)
|
10 |
+
async def root():
|
11 |
+
result = HTML(
|
12 |
+
Head(
|
13 |
+
title="Dropdown with Edit Sheet Example",
|
14 |
+
),
|
15 |
+
Body(
|
16 |
+
Div(
|
17 |
+
dropdown.dropdown_menu("Actions"),
|
18 |
+
Div(id="sheet-container"), # 这里是 sheet 将被加载的地方
|
19 |
+
class_="container mx-auto p-4"
|
20 |
+
)
|
21 |
+
)
|
22 |
+
).render()
|
23 |
+
print(result)
|
24 |
+
return result
|
25 |
+
|
26 |
+
@app.get("/dropdown-menu/{menu_id}", response_class=HTMLResponse)
|
27 |
+
async def get_dropdown_menu_content(menu_id: str):
|
28 |
+
items = [
|
29 |
+
{
|
30 |
+
"icon": "pencil",
|
31 |
+
"label": "Edit",
|
32 |
+
"hx-get": "/edit-sheet",
|
33 |
+
"hx-target": "#sheet-container",
|
34 |
+
"hx-swap": "innerHTML"
|
35 |
+
},
|
36 |
+
{"icon": "trash", "label": "Delete"},
|
37 |
+
{"icon": "copy", "label": "Duplicate"},
|
38 |
+
]
|
39 |
+
result = dropdown.dropdown_menu_content(menu_id, items).render()
|
40 |
+
print("dropdown-menu result", result)
|
41 |
+
return result
|
42 |
+
|
43 |
+
@app.get("/edit-sheet", response_class=HTMLResponse)
|
44 |
+
async def get_edit_sheet():
|
45 |
+
edit_sheet_content = sheet.SheetContent(
|
46 |
+
sheet.SheetHeader(
|
47 |
+
sheet.SheetTitle("Edit Item"),
|
48 |
+
sheet.SheetDescription("Make changes to your item here.")
|
49 |
+
),
|
50 |
+
sheet.SheetBody(
|
51 |
+
form.Form(
|
52 |
+
form.FormField("Name", "name", placeholder="Enter item name"),
|
53 |
+
form.FormField("Description", "description", placeholder="Enter item description"),
|
54 |
+
Div(
|
55 |
+
button.button("Save", class_="bg-blue-500 text-white"),
|
56 |
+
button.button("Cancel", class_="bg-gray-300 text-gray-700 ml-2", data_close_sheet="true"),
|
57 |
+
class_="flex justify-end mt-4"
|
58 |
+
),
|
59 |
+
class_="space-y-4"
|
60 |
+
)
|
61 |
+
)
|
62 |
+
)
|
63 |
+
|
64 |
+
result = sheet.Sheet(
|
65 |
+
"edit-sheet",
|
66 |
+
Div(),
|
67 |
+
edit_sheet_content,
|
68 |
+
width="80%",
|
69 |
+
max_width="800px"
|
70 |
+
).render()
|
71 |
+
return result
|
72 |
+
|
73 |
+
if __name__ == "__main__":
|
74 |
+
import uvicorn
|
75 |
+
uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
|
test/xue/test_form_uni_api.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, Form as FastAPIForm
|
2 |
+
from fastapi.responses import HTMLResponse
|
3 |
+
from xue import HTML, Head, Body, Div, xue_initialize, Strong, Span, Ul, Li
|
4 |
+
from xue.components import form, button, checkbox, input
|
5 |
+
from xue.components.model_config_row import model_config_row
|
6 |
+
from typing import List, Optional
|
7 |
+
import time
|
8 |
+
|
9 |
+
xue_initialize(tailwind=True)
|
10 |
+
app = FastAPI()
|
11 |
+
|
12 |
+
@app.get("/", response_class=HTMLResponse)
|
13 |
+
async def root():
|
14 |
+
result = HTML(
|
15 |
+
Head(
|
16 |
+
title="Provider Configuration Form"
|
17 |
+
),
|
18 |
+
Body(
|
19 |
+
Div(
|
20 |
+
form.Form(
|
21 |
+
form.FormField("Provider", "provider", placeholder="Enter provider name", required=True),
|
22 |
+
form.FormField("Base URL", "base_url", placeholder="Enter base URL", required=True),
|
23 |
+
form.FormField("API Key", "api_key", type="password", placeholder="Enter API key"),
|
24 |
+
Div(
|
25 |
+
Div("Models", class_="text-lg font-semibold mb-2"),
|
26 |
+
Div(
|
27 |
+
model_config_row("model1", "gpt-4o: deepbricks-gpt-4o-mini", True),
|
28 |
+
model_config_row("model2", "gpt-4o"),
|
29 |
+
model_config_row("model3", "gpt-3.5-turbo"),
|
30 |
+
model_config_row("model4", "claude-3-5-sonnet-20240620: claude-3-5-sonnet"),
|
31 |
+
model_config_row("model5", "o1-mini-all"),
|
32 |
+
model_config_row("model6", "o1-preview-all"),
|
33 |
+
model_config_row("model7", "whisper-1"),
|
34 |
+
id="models-container"
|
35 |
+
),
|
36 |
+
button.button(
|
37 |
+
"Add Model",
|
38 |
+
class_="mt-2",
|
39 |
+
hx_post="/add-model",
|
40 |
+
hx_target="#models-container",
|
41 |
+
hx_swap="beforeend"
|
42 |
+
),
|
43 |
+
class_="mb-4"
|
44 |
+
),
|
45 |
+
Div(
|
46 |
+
checkbox.checkbox("tools", "Enable Tools", checked=True),
|
47 |
+
class_="mb-4"
|
48 |
+
),
|
49 |
+
form.FormField("Notes", "notes", placeholder="Enter any additional notes"),
|
50 |
+
Div(
|
51 |
+
button.button("Submit", class_="bg-blue-500 text-white"),
|
52 |
+
button.button("Cancel", class_="bg-gray-300 text-gray-700 ml-2"),
|
53 |
+
class_="flex justify-end mt-4"
|
54 |
+
),
|
55 |
+
hx_post="/submit",
|
56 |
+
hx_swap="outerHTML",
|
57 |
+
class_="space-y-4"
|
58 |
+
),
|
59 |
+
class_="container mx-auto p-4 max-w-2xl"
|
60 |
+
)
|
61 |
+
)
|
62 |
+
).render()
|
63 |
+
print(result)
|
64 |
+
return result
|
65 |
+
|
66 |
+
@app.post("/add-model", response_class=HTMLResponse)
|
67 |
+
async def add_model():
|
68 |
+
new_model_id = f"model{hash(str(time.time()))}" # 生成一个唯一的ID
|
69 |
+
new_model = model_config_row(new_model_id).render()
|
70 |
+
return new_model
|
71 |
+
|
72 |
+
def form_success_message(provider, base_url, api_key, models, tools_enabled, notes):
|
73 |
+
return Div(
|
74 |
+
Strong("Success!", class_="font-bold"),
|
75 |
+
Span("Form submitted successfully.", class_="block sm:inline"),
|
76 |
+
Ul(
|
77 |
+
Li(f"Provider: {provider}"),
|
78 |
+
Li(f"Base URL: {base_url}"),
|
79 |
+
Li(f"API Key: {'*' * len(api_key)}"),
|
80 |
+
Li(f"Models: {', '.join(models)}"),
|
81 |
+
Li(f"Tools Enabled: {'Yes' if tools_enabled else 'No'}"),
|
82 |
+
Li(f"Notes: {notes}"),
|
83 |
+
class_="mt-3"
|
84 |
+
),
|
85 |
+
class_="bg-green-100 border border-green-400 text-green-700 px-4 py-3 rounded relative",
|
86 |
+
role="alert"
|
87 |
+
)
|
88 |
+
|
89 |
+
@app.post("/submit", response_class=HTMLResponse)
|
90 |
+
async def submit_form(
|
91 |
+
provider: str = FastAPIForm(...),
|
92 |
+
base_url: str = FastAPIForm(...),
|
93 |
+
api_key: str = FastAPIForm(...),
|
94 |
+
models: List[str] = FastAPIForm([]),
|
95 |
+
tools: Optional[str] = FastAPIForm(None),
|
96 |
+
notes: Optional[str] = FastAPIForm(None)
|
97 |
+
):
|
98 |
+
# 处理提交的数据
|
99 |
+
print(f"Received: provider={provider}, base_url={base_url}, api_key={api_key}")
|
100 |
+
print(f"Models: {models}")
|
101 |
+
print(f"Tools Enabled: {tools is not None}")
|
102 |
+
print(f"Notes: {notes}")
|
103 |
+
|
104 |
+
# 返回处理结果
|
105 |
+
return form_success_message(
|
106 |
+
provider,
|
107 |
+
base_url,
|
108 |
+
api_key,
|
109 |
+
models,
|
110 |
+
tools is not None,
|
111 |
+
notes or "No notes provided"
|
112 |
+
).render()
|
113 |
+
|
114 |
+
if __name__ == "__main__":
|
115 |
+
import uvicorn
|
116 |
+
uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
|
test/xue/test_home.py
ADDED
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, Request
|
2 |
+
from fastapi import Form as FastapiForm, HTTPException, Depends
|
3 |
+
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
4 |
+
from fastapi.security import APIKeyHeader
|
5 |
+
from typing import Optional, List
|
6 |
+
|
7 |
+
from xue import HTML, Head, Body, Div, xue_initialize, Script
|
8 |
+
from xue.components.menubar import (
|
9 |
+
Menubar, MenubarMenu, MenubarTrigger, MenubarContent,
|
10 |
+
MenubarItem, MenubarSeparator
|
11 |
+
)
|
12 |
+
from xue.components import input
|
13 |
+
from xue.components import dropdown, sheet, form, button, checkbox
|
14 |
+
from xue.components.model_config_row import model_config_row
|
15 |
+
import time
|
16 |
+
|
17 |
+
import sys
|
18 |
+
import os
|
19 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
20 |
+
from components.provider_table import data_table
|
21 |
+
|
22 |
+
|
23 |
+
from ruamel.yaml import YAML
|
24 |
+
yaml = YAML()
|
25 |
+
yaml.preserve_quotes = True
|
26 |
+
yaml.indent(mapping=2, sequence=4, offset=2)
|
27 |
+
|
28 |
+
xue_initialize(tailwind=True)
|
29 |
+
|
30 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
31 |
+
import logging
|
32 |
+
|
33 |
+
logging.basicConfig(level=logging.INFO)
|
34 |
+
logger = logging.getLogger(__name__)
|
35 |
+
|
36 |
+
class RequestBodyLoggerMiddleware(BaseHTTPMiddleware):
|
37 |
+
async def dispatch(self, request: Request, call_next):
|
38 |
+
if request.method == "POST" and request.url.path.startswith("/submit/"):
|
39 |
+
# if request.method == "POST":
|
40 |
+
body = await request.body()
|
41 |
+
logger.info(f"Request body for {request.url.path}: {body.decode()}")
|
42 |
+
|
43 |
+
response = await call_next(request)
|
44 |
+
return response
|
45 |
+
|
46 |
+
from utils import load_config
|
47 |
+
from contextlib import asynccontextmanager
|
48 |
+
@asynccontextmanager
|
49 |
+
async def lifespan(app: FastAPI):
|
50 |
+
# app.state.client = httpx.AsyncClient(timeout=timeout)
|
51 |
+
app.state.config, app.state.api_keys_db, app.state.api_list = await load_config()
|
52 |
+
for item in app.state.api_keys_db:
|
53 |
+
if item.get("role") == "admin":
|
54 |
+
app.state.admin_api_key = item.get("api")
|
55 |
+
if not hasattr(app.state, "admin_api_key"):
|
56 |
+
if len(app.state.api_keys_db) >= 1:
|
57 |
+
app.state.admin_api_key = app.state.api_keys_db[0].get("api")
|
58 |
+
else:
|
59 |
+
raise Exception("No admin API key found")
|
60 |
+
|
61 |
+
global data
|
62 |
+
# providers_data = app.state.config["providers"]
|
63 |
+
|
64 |
+
# print("data", data)
|
65 |
+
yield
|
66 |
+
# 关闭时的代码
|
67 |
+
await app.state.client.aclose()
|
68 |
+
|
69 |
+
app = FastAPI(lifespan=lifespan)
|
70 |
+
# app.add_middleware(RequestBodyLoggerMiddleware)
|
71 |
+
app.add_middleware(RequestBodyLoggerMiddleware)
|
72 |
+
|
73 |
+
data_table_columns = [
|
74 |
+
# {"label": "Status", "value": "status", "sortable": True},
|
75 |
+
{"label": "Provider", "value": "provider", "sortable": True},
|
76 |
+
{"label": "Base url", "value": "base_url", "sortable": True},
|
77 |
+
# {"label": "Engine", "value": "engine", "sortable": True},
|
78 |
+
{"label": "Tools", "value": "tools", "sortable": True},
|
79 |
+
]
|
80 |
+
|
81 |
+
API_KEY_NAME = "X-API-Key"
|
82 |
+
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
83 |
+
|
84 |
+
@app.get("/login", response_class=HTMLResponse)
|
85 |
+
async def login_page():
|
86 |
+
return HTML(
|
87 |
+
Head(title="登录"),
|
88 |
+
Body(
|
89 |
+
Div(
|
90 |
+
form.Form(
|
91 |
+
form.FormField("API Key", "x_api_key", type="password", placeholder="输入API密钥", required=True),
|
92 |
+
Div(id="error-message", class_="text-red-500 mt-2"),
|
93 |
+
Div(
|
94 |
+
button.button("提交", variant="primary", type="submit"),
|
95 |
+
class_="flex justify-end mt-4"
|
96 |
+
),
|
97 |
+
hx_post="/verify-api-key",
|
98 |
+
hx_target="#error-message",
|
99 |
+
hx_swap="innerHTML",
|
100 |
+
class_="space-y-4"
|
101 |
+
),
|
102 |
+
class_="container mx-auto p-4 max-w-md"
|
103 |
+
)
|
104 |
+
)
|
105 |
+
).render()
|
106 |
+
|
107 |
+
|
108 |
+
@app.post("/verify-api-key", response_class=HTMLResponse)
|
109 |
+
async def verify_api_key(x_api_key: str = FastapiForm(...)):
|
110 |
+
if x_api_key == app.state.admin_api_key: # 替换为实际的管理员API密钥
|
111 |
+
response = JSONResponse(content={"success": True})
|
112 |
+
response.headers["HX-Redirect"] = "/" # 添加这一行
|
113 |
+
response.set_cookie(
|
114 |
+
key="x_api_key",
|
115 |
+
value=x_api_key,
|
116 |
+
httponly=True,
|
117 |
+
max_age=1800, # 30分钟
|
118 |
+
secure=False, # 在开发环境中设置为False,生产环境中使用HTTPS时设置为True
|
119 |
+
samesite="lax" # 改为"lax"以允许重定向时携带cookie
|
120 |
+
)
|
121 |
+
return response
|
122 |
+
else:
|
123 |
+
return Div("无效的API密钥", class_="text-red-500").render()
|
124 |
+
|
125 |
+
async def get_api_key(request: Request, x_api_key: Optional[str] = Depends(api_key_header)):
|
126 |
+
if not x_api_key:
|
127 |
+
x_api_key = request.cookies.get("x_api_key") or request.query_params.get("x_api_key")
|
128 |
+
# print(f"Cookie x_api_key: {request.cookies.get('x_api_key')}") # 添加此行
|
129 |
+
# print(f"Query param x_api_key: {request.query_params.get('x_api_key')}") # 添加此行
|
130 |
+
# print(f"Header x_api_key: {x_api_key}") # 添加此��
|
131 |
+
# logger.info(f"x_api_key: {x_api_key} {x_api_key == 'your_admin_api_key'}")
|
132 |
+
|
133 |
+
if x_api_key == app.state.admin_api_key: # 替换为实际的管理员API密钥
|
134 |
+
return x_api_key
|
135 |
+
else:
|
136 |
+
return None
|
137 |
+
|
138 |
+
@app.get("/", response_class=HTMLResponse)
|
139 |
+
async def root(x_api_key: str = Depends(get_api_key)):
|
140 |
+
if not x_api_key:
|
141 |
+
return RedirectResponse(url="/login", status_code=303)
|
142 |
+
|
143 |
+
result = HTML(
|
144 |
+
Head(
|
145 |
+
Script("""
|
146 |
+
document.addEventListener('DOMContentLoaded', function() {
|
147 |
+
const filterInput = document.getElementById('users-table-filter');
|
148 |
+
filterInput.addEventListener('input', function() {
|
149 |
+
const filterValue = this.value;
|
150 |
+
htmx.ajax('GET', `/filter-table?filter=${filterValue}`, '#users-table');
|
151 |
+
});
|
152 |
+
});
|
153 |
+
"""),
|
154 |
+
title="Menubar Example"
|
155 |
+
),
|
156 |
+
Body(
|
157 |
+
Div(
|
158 |
+
Menubar(
|
159 |
+
MenubarMenu(
|
160 |
+
MenubarTrigger("File", "file-menu"),
|
161 |
+
MenubarContent(
|
162 |
+
MenubarItem("New Tab", shortcut="⌘T"),
|
163 |
+
MenubarItem("New Window", shortcut="⌘N"),
|
164 |
+
MenubarItem("New Incognito Window", disabled=True),
|
165 |
+
MenubarSeparator(),
|
166 |
+
MenubarItem("Print...", shortcut="⌘P"),
|
167 |
+
),
|
168 |
+
id="file-menu"
|
169 |
+
),
|
170 |
+
MenubarMenu(
|
171 |
+
MenubarTrigger("Edit", "edit-menu"),
|
172 |
+
MenubarContent(
|
173 |
+
MenubarItem("Undo", shortcut="⌘Z"),
|
174 |
+
MenubarItem("Redo", shortcut="⇧⌘Z"),
|
175 |
+
MenubarSeparator(),
|
176 |
+
MenubarItem("Cut"),
|
177 |
+
MenubarItem("Copy"),
|
178 |
+
MenubarItem("Paste"),
|
179 |
+
),
|
180 |
+
id="edit-menu"
|
181 |
+
),
|
182 |
+
MenubarMenu(
|
183 |
+
MenubarTrigger("View", "view-menu"),
|
184 |
+
MenubarContent(
|
185 |
+
MenubarItem("Always Show Bookmarks Bar"),
|
186 |
+
MenubarItem("Always Show Full URLs"),
|
187 |
+
MenubarSeparator(),
|
188 |
+
MenubarItem("Reload", shortcut="⌘R"),
|
189 |
+
MenubarItem("Force Reload", shortcut="⇧⌘R", disabled=True),
|
190 |
+
MenubarSeparator(),
|
191 |
+
MenubarItem("Toggle Fullscreen"),
|
192 |
+
MenubarItem("Hide Sidebar"),
|
193 |
+
),
|
194 |
+
id="view-menu"
|
195 |
+
),
|
196 |
+
),
|
197 |
+
class_="p-4"
|
198 |
+
),
|
199 |
+
Div(
|
200 |
+
data_table(data_table_columns, app.state.config["providers"], "users-table"),
|
201 |
+
class_="p-4"
|
202 |
+
),
|
203 |
+
Div(id="sheet-container"), # 这里是 sheet 将被加载的地方
|
204 |
+
class_="container mx-auto",
|
205 |
+
id="body"
|
206 |
+
)
|
207 |
+
).render()
|
208 |
+
# print(result)
|
209 |
+
return result
|
210 |
+
|
211 |
+
@app.get("/dropdown-menu/{menu_id}/{row_id}", response_class=HTMLResponse)
|
212 |
+
async def get_columns_menu(menu_id: str, row_id: str):
|
213 |
+
columns = [
|
214 |
+
{
|
215 |
+
"label": "Edit",
|
216 |
+
"value": "edit",
|
217 |
+
"hx-get": f"/edit-sheet/{row_id}",
|
218 |
+
"hx-target": "#sheet-container",
|
219 |
+
"hx-swap": "innerHTML"
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"label": "Duplicate",
|
223 |
+
"value": "duplicate",
|
224 |
+
"hx-post": f"/duplicate/{row_id}",
|
225 |
+
"hx-target": "body",
|
226 |
+
"hx-swap": "outerHTML"
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"label": "Delete",
|
230 |
+
"value": "delete",
|
231 |
+
"hx-delete": f"/delete/{row_id}",
|
232 |
+
"hx-target": "body",
|
233 |
+
"hx-swap": "outerHTML",
|
234 |
+
"hx-confirm": "确定要删除这个配置吗?"
|
235 |
+
},
|
236 |
+
]
|
237 |
+
result = dropdown.dropdown_menu_content(menu_id, columns).render()
|
238 |
+
print(result)
|
239 |
+
return result
|
240 |
+
|
241 |
+
@app.get("/dropdown-menu/{menu_id}", response_class=HTMLResponse)
|
242 |
+
async def get_columns_menu(menu_id: str):
|
243 |
+
result = dropdown.dropdown_menu_content(menu_id, data_table_columns).render()
|
244 |
+
print(result)
|
245 |
+
return result
|
246 |
+
|
247 |
+
@app.get("/filter-table", response_class=HTMLResponse)
|
248 |
+
async def filter_table(filter: str = ""):
|
249 |
+
filtered_data = [
|
250 |
+
provider for provider in app.state.config["providers"]
|
251 |
+
if filter.lower() in str(provider["provider"]).lower() or
|
252 |
+
filter.lower() in str(provider["base_url"]).lower() or
|
253 |
+
filter.lower() in str(provider["tools"]).lower()
|
254 |
+
]
|
255 |
+
return data_table(data_table_columns, filtered_data, "users-table", with_filter=False).render()
|
256 |
+
|
257 |
+
@app.post("/add-model", response_class=HTMLResponse)
|
258 |
+
async def add_model():
|
259 |
+
new_model_id = f"model{hash(str(time.time()))}" # 生成一个唯一的ID
|
260 |
+
new_model = model_config_row(new_model_id).render()
|
261 |
+
return new_model
|
262 |
+
|
263 |
+
@app.get("/edit-sheet/{row_id}", response_class=HTMLResponse)
|
264 |
+
async def get_edit_sheet(row_id: str, x_api_key: str = Depends(get_api_key)):
|
265 |
+
row_data = get_row_data(row_id)
|
266 |
+
print("row_data", row_data)
|
267 |
+
|
268 |
+
model_list = []
|
269 |
+
for index, model in enumerate(row_data["model"]):
|
270 |
+
if isinstance(model, str):
|
271 |
+
model_list.append(model_config_row(f"model{index}", model, "", True))
|
272 |
+
if isinstance(model, dict):
|
273 |
+
# print("model", model, list(model.items())[0])
|
274 |
+
key, value = list(model.items())[0]
|
275 |
+
model_list.append(model_config_row(f"model{index}", key, value, True))
|
276 |
+
|
277 |
+
sheet_id = "edit-sheet"
|
278 |
+
edit_sheet_content = sheet.SheetContent(
|
279 |
+
sheet.SheetHeader(
|
280 |
+
sheet.SheetTitle("Edit Item"),
|
281 |
+
sheet.SheetDescription("Make changes to your item here.")
|
282 |
+
),
|
283 |
+
sheet.SheetBody(
|
284 |
+
Div(
|
285 |
+
form.Form(
|
286 |
+
form.FormField("Provider", "provider", value=row_data["provider"], placeholder="Enter provider name", required=True),
|
287 |
+
form.FormField("Base URL", "base_url", value=row_data["base_url"], placeholder="Enter base URL", required=True),
|
288 |
+
form.FormField("API Key", "api_key", value=row_data["api"], type="text", placeholder="Enter API key"),
|
289 |
+
Div(
|
290 |
+
Div("Models", class_="text-lg font-semibold mb-2"),
|
291 |
+
Div(
|
292 |
+
*model_list,
|
293 |
+
id="models-container"
|
294 |
+
),
|
295 |
+
button.button(
|
296 |
+
"Add Model",
|
297 |
+
class_="mt-2",
|
298 |
+
hx_post="/add-model",
|
299 |
+
hx_target="#models-container",
|
300 |
+
hx_swap="beforeend"
|
301 |
+
),
|
302 |
+
class_="mb-4"
|
303 |
+
),
|
304 |
+
Div(
|
305 |
+
checkbox.checkbox("tools", "Enable Tools", checked=row_data["tools"], name="tools"),
|
306 |
+
class_="mb-4"
|
307 |
+
),
|
308 |
+
form.FormField("Notes", "notes", value=row_data.get("notes", ""), placeholder="Enter any additional notes"),
|
309 |
+
Div(
|
310 |
+
button.button("Submit", variant="primary", type="submit"),
|
311 |
+
button.button("Cancel", variant="outline", type="button", class_="ml-2", onclick=f"toggleSheet('{sheet_id}')"),
|
312 |
+
class_="flex justify-end mt-4"
|
313 |
+
),
|
314 |
+
hx_post=f"/submit/{row_id}",
|
315 |
+
hx_swap="outerHTML",
|
316 |
+
hx_target="body",
|
317 |
+
class_="space-y-4"
|
318 |
+
),
|
319 |
+
class_="container mx-auto p-4 max-w-2xl"
|
320 |
+
)
|
321 |
+
)
|
322 |
+
)
|
323 |
+
|
324 |
+
result = sheet.Sheet(
|
325 |
+
sheet_id,
|
326 |
+
Div(),
|
327 |
+
edit_sheet_content,
|
328 |
+
width="80%",
|
329 |
+
max_width="800px"
|
330 |
+
).render()
|
331 |
+
return result
|
332 |
+
|
333 |
+
@app.get("/add-provider-sheet", response_class=HTMLResponse)
|
334 |
+
async def get_add_provider_sheet():
|
335 |
+
edit_sheet_content = sheet.SheetContent(
|
336 |
+
sheet.SheetHeader(
|
337 |
+
sheet.SheetTitle("Add New Provider"),
|
338 |
+
sheet.SheetDescription("Enter details for the new provider.")
|
339 |
+
),
|
340 |
+
sheet.SheetBody(
|
341 |
+
Div(
|
342 |
+
form.Form(
|
343 |
+
form.FormField("Provider", "provider", placeholder="Enter provider name", required=True),
|
344 |
+
form.FormField("Base URL", "base_url", placeholder="Enter base URL", required=True),
|
345 |
+
form.FormField("API Key", "api_key", type="text", placeholder="Enter API key"),
|
346 |
+
Div(
|
347 |
+
Div("Models", class_="text-lg font-semibold mb-2"),
|
348 |
+
Div(id="models-container"),
|
349 |
+
button.button(
|
350 |
+
"Add Model",
|
351 |
+
class_="mt-2",
|
352 |
+
hx_post="/add-model",
|
353 |
+
hx_target="#models-container",
|
354 |
+
hx_swap="beforeend"
|
355 |
+
),
|
356 |
+
class_="mb-4"
|
357 |
+
),
|
358 |
+
Div(
|
359 |
+
checkbox.checkbox("tools", "Enable Tools", name="tools"),
|
360 |
+
class_="mb-4"
|
361 |
+
),
|
362 |
+
form.FormField("Notes", "notes", placeholder="Enter any additional notes"),
|
363 |
+
Div(
|
364 |
+
button.button("Submit", variant="primary", type="submit"),
|
365 |
+
button.button("Cancel", variant="outline", class_="ml-2"),
|
366 |
+
class_="flex justify-end mt-4"
|
367 |
+
),
|
368 |
+
hx_post="/submit/new",
|
369 |
+
hx_swap="outerHTML",
|
370 |
+
hx_target="body",
|
371 |
+
class_="space-y-4"
|
372 |
+
),
|
373 |
+
class_="container mx-auto p-4 max-w-2xl"
|
374 |
+
)
|
375 |
+
)
|
376 |
+
)
|
377 |
+
|
378 |
+
result = sheet.Sheet(
|
379 |
+
"add-provider-sheet",
|
380 |
+
Div(),
|
381 |
+
edit_sheet_content,
|
382 |
+
width="80%",
|
383 |
+
max_width="800px"
|
384 |
+
).render()
|
385 |
+
return result
|
386 |
+
|
387 |
+
def get_row_data(row_id):
|
388 |
+
index = int(row_id)
|
389 |
+
# print(app.state.config["providers"])
|
390 |
+
return app.state.config["providers"][index]
|
391 |
+
|
392 |
+
def update_row_data(row_id, updated_data):
|
393 |
+
print(row_id, updated_data)
|
394 |
+
index = int(row_id)
|
395 |
+
app.state.config["providers"][index] = updated_data
|
396 |
+
with open("./api1.yaml", "w", encoding="utf-8") as f:
|
397 |
+
yaml.dump(app.state.config, f)
|
398 |
+
|
399 |
+
@app.post("/submit/{row_id}", response_class=HTMLResponse)
|
400 |
+
async def submit_form(
|
401 |
+
row_id: str,
|
402 |
+
request: Request,
|
403 |
+
provider: str = FastapiForm(...),
|
404 |
+
base_url: str = FastapiForm(...),
|
405 |
+
api_key: Optional[str] = FastapiForm(None),
|
406 |
+
tools: Optional[str] = FastapiForm(None),
|
407 |
+
notes: Optional[str] = FastapiForm(None),
|
408 |
+
x_api_key: str = Depends(get_api_key)
|
409 |
+
):
|
410 |
+
form_data = await request.form()
|
411 |
+
|
412 |
+
# 收集模型数据
|
413 |
+
models = []
|
414 |
+
for key, value in form_data.items():
|
415 |
+
if key.startswith("model_name_"):
|
416 |
+
model_id = key.split("_")[-1]
|
417 |
+
enabled = form_data.get(f"model_enabled_{model_id}") == "on"
|
418 |
+
rename = form_data.get(f"model_rename_{model_id}")
|
419 |
+
if value:
|
420 |
+
if rename:
|
421 |
+
models.append({value: rename})
|
422 |
+
else:
|
423 |
+
models.append(value)
|
424 |
+
|
425 |
+
updated_data = {
|
426 |
+
"provider": provider,
|
427 |
+
"base_url": base_url,
|
428 |
+
"api": api_key,
|
429 |
+
"model": models,
|
430 |
+
"tools": tools == "on",
|
431 |
+
"notes": notes,
|
432 |
+
}
|
433 |
+
|
434 |
+
print("updated_data", updated_data)
|
435 |
+
|
436 |
+
if row_id == "new":
|
437 |
+
# 添加新提供者
|
438 |
+
app.state.config["providers"].append(updated_data)
|
439 |
+
else:
|
440 |
+
# 更新现有提供者
|
441 |
+
update_row_data(row_id, updated_data)
|
442 |
+
|
443 |
+
# 保存更新后的配置
|
444 |
+
with open("./api1.yaml", "w", encoding="utf-8") as f:
|
445 |
+
yaml.dump(app.state.config, f)
|
446 |
+
|
447 |
+
return await root()
|
448 |
+
|
449 |
+
@app.post("/duplicate/{row_id}", response_class=HTMLResponse)
|
450 |
+
async def duplicate_row(row_id: str):
|
451 |
+
index = int(row_id)
|
452 |
+
original_data = app.state.config["providers"][index]
|
453 |
+
new_data = original_data.copy()
|
454 |
+
new_data["provider"] += "-copy"
|
455 |
+
app.state.config["providers"].insert(index + 1, new_data)
|
456 |
+
|
457 |
+
# 保存更新后的配置
|
458 |
+
with open("./api1.yaml", "w", encoding="utf-8") as f:
|
459 |
+
yaml.dump(app.state.config, f)
|
460 |
+
|
461 |
+
return await root()
|
462 |
+
|
463 |
+
@app.delete("/delete/{row_id}", response_class=HTMLResponse)
|
464 |
+
async def delete_row(row_id: str):
|
465 |
+
index = int(row_id)
|
466 |
+
del app.state.config["providers"][index]
|
467 |
+
|
468 |
+
# 保存更新后的配置
|
469 |
+
with open("./api1.yaml", "w", encoding="utf-8") as f:
|
470 |
+
yaml.dump(app.state.config, f)
|
471 |
+
|
472 |
+
return await root()
|
473 |
+
|
474 |
+
if __name__ == "__main__":
|
475 |
+
import uvicorn
|
476 |
+
uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
|
utils.py
CHANGED
@@ -3,26 +3,74 @@ from fastapi import HTTPException
|
|
3 |
import httpx
|
4 |
|
5 |
from log_config import logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def update_config(config_data):
|
8 |
for index, provider in enumerate(config_data['providers']):
|
9 |
-
model_dict = {}
|
10 |
-
for model in provider['model']:
|
11 |
-
if type(model) == str:
|
12 |
-
model_dict[model] = model
|
13 |
-
if type(model) == dict:
|
14 |
-
model_dict.update({new: old for old, new in model.items()})
|
15 |
-
provider['model'] = model_dict
|
16 |
if provider.get('project_id'):
|
17 |
provider['base_url'] = 'https://aiplatform.googleapis.com/'
|
18 |
if provider.get('cf_account_id'):
|
19 |
provider['base_url'] = 'https://api.cloudflare.com/'
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
26 |
|
27 |
config_data['providers'][index] = provider
|
28 |
|
@@ -46,31 +94,24 @@ def update_config(config_data):
|
|
46 |
api_keys_db[index]['model'] = models
|
47 |
|
48 |
api_list = [item["api"] for item in api_keys_db]
|
49 |
-
# logger.info(json.dumps(config_data, indent=4, ensure_ascii=False
|
50 |
return config_data, api_keys_db, api_list
|
51 |
|
52 |
# 读取YAML配置文件
|
53 |
async def load_config(app=None):
|
54 |
-
import yaml
|
55 |
try:
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
#
|
67 |
-
|
68 |
-
# conf = None
|
69 |
-
if conf:
|
70 |
-
config, api_keys_db, api_list = update_config(conf)
|
71 |
-
else:
|
72 |
-
# logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
|
73 |
-
config, api_keys_db, api_list = [], [], []
|
74 |
except FileNotFoundError:
|
75 |
logger.error("'api.yaml' not found. Please check the file path.")
|
76 |
config, api_keys_db, api_list = [], [], []
|
@@ -228,7 +269,8 @@ def get_all_models(config):
|
|
228 |
unique_models = set()
|
229 |
|
230 |
for provider in config["providers"]:
|
231 |
-
|
|
|
232 |
if model not in unique_models:
|
233 |
unique_models.add(model)
|
234 |
model_info = {
|
@@ -260,35 +302,12 @@ def get_all_models(config):
|
|
260 |
# europe-west1
|
261 |
# europe-west4
|
262 |
|
263 |
-
def circular_list_encoder(obj):
|
264 |
-
if isinstance(obj, CircularList):
|
265 |
-
return obj.to_dict()
|
266 |
-
raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
|
267 |
-
|
268 |
-
from collections import deque
|
269 |
-
class CircularList:
|
270 |
-
def __init__(self, items):
|
271 |
-
self.queue = deque(items)
|
272 |
-
|
273 |
-
def next(self):
|
274 |
-
if not self.queue:
|
275 |
-
return None
|
276 |
-
item = self.queue.popleft()
|
277 |
-
self.queue.append(item)
|
278 |
-
return item
|
279 |
-
|
280 |
-
def to_dict(self):
|
281 |
-
return {
|
282 |
-
'queue': list(self.queue)
|
283 |
-
}
|
284 |
-
|
285 |
-
|
286 |
|
287 |
-
c35s =
|
288 |
-
c3s =
|
289 |
-
c3o =
|
290 |
-
c3h =
|
291 |
-
gem =
|
292 |
|
293 |
class BaseAPI:
|
294 |
def __init__(
|
|
|
3 |
import httpx
|
4 |
|
5 |
from log_config import logger
|
6 |
+
from collections import defaultdict
|
7 |
+
|
8 |
+
import asyncio
|
9 |
+
|
10 |
+
class ThreadSafeCircularList:
|
11 |
+
def __init__(self, items):
|
12 |
+
self.items = items
|
13 |
+
self.index = 0
|
14 |
+
self.lock = asyncio.Lock()
|
15 |
+
|
16 |
+
async def next(self):
|
17 |
+
async with self.lock:
|
18 |
+
item = self.items[self.index]
|
19 |
+
self.index = (self.index + 1) % len(self.items)
|
20 |
+
return item
|
21 |
+
|
22 |
+
def circular_list_encoder(obj):
|
23 |
+
if isinstance(obj, ThreadSafeCircularList):
|
24 |
+
return obj.to_dict()
|
25 |
+
raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
|
26 |
+
|
27 |
+
provider_api_circular_list = defaultdict(ThreadSafeCircularList)
|
28 |
+
|
29 |
+
def get_model_dict(provider):
|
30 |
+
model_dict = {}
|
31 |
+
for model in provider['model']:
|
32 |
+
if type(model) == str:
|
33 |
+
model_dict[model] = model
|
34 |
+
if type(model) == dict:
|
35 |
+
model_dict.update({new: old for old, new in model.items()})
|
36 |
+
return model_dict
|
37 |
+
|
38 |
+
def update_initial_model(api_url, api):
|
39 |
+
try:
|
40 |
+
endpoint = BaseAPI(api_url=api_url)
|
41 |
+
endpoint_models_url = endpoint.v1_models
|
42 |
+
response = httpx.get(
|
43 |
+
endpoint_models_url,
|
44 |
+
headers={"Authorization": f"Bearer {api}"},
|
45 |
+
)
|
46 |
+
models = response.json()
|
47 |
+
# print(models)
|
48 |
+
models_list = models["data"]
|
49 |
+
models_id = [model["id"] for model in models_list]
|
50 |
+
set_models = set()
|
51 |
+
for model_item in models_id:
|
52 |
+
set_models.add(model_item)
|
53 |
+
models_id = list(set_models)
|
54 |
+
# print(models_id)
|
55 |
+
return models_id
|
56 |
+
except Exception as e:
|
57 |
+
print("error:", e)
|
58 |
+
return []
|
59 |
|
60 |
def update_config(config_data):
|
61 |
for index, provider in enumerate(config_data['providers']):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
if provider.get('project_id'):
|
63 |
provider['base_url'] = 'https://aiplatform.googleapis.com/'
|
64 |
if provider.get('cf_account_id'):
|
65 |
provider['base_url'] = 'https://api.cloudflare.com/'
|
66 |
|
67 |
+
provider_api_circular_list[provider['provider']] = ThreadSafeCircularList([provider.get('api', None)])
|
68 |
+
|
69 |
+
if not provider.get("model"):
|
70 |
+
provider["model"] = update_initial_model(provider['base_url'], provider['api'])
|
71 |
+
|
72 |
+
if provider.get("tools") == None:
|
73 |
+
provider["tools"] = True
|
74 |
|
75 |
config_data['providers'][index] = provider
|
76 |
|
|
|
94 |
api_keys_db[index]['model'] = models
|
95 |
|
96 |
api_list = [item["api"] for item in api_keys_db]
|
97 |
+
# logger.info(json.dumps(config_data, indent=4, ensure_ascii=False))
|
98 |
return config_data, api_keys_db, api_list
|
99 |
|
100 |
# 读取YAML配置文件
|
101 |
async def load_config(app=None):
|
|
|
102 |
try:
|
103 |
+
from ruamel.yaml import YAML
|
104 |
+
yaml = YAML()
|
105 |
+
yaml.preserve_quotes = True
|
106 |
+
yaml.indent(mapping=2, sequence=4, offset=2)
|
107 |
+
with open('api.yaml', 'r', encoding='utf-8') as file:
|
108 |
+
conf = yaml.load(file)
|
109 |
+
|
110 |
+
if conf:
|
111 |
+
config, api_keys_db, api_list = update_config(conf)
|
112 |
+
else:
|
113 |
+
# logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
|
114 |
+
config, api_keys_db, api_list = [], [], []
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
except FileNotFoundError:
|
116 |
logger.error("'api.yaml' not found. Please check the file path.")
|
117 |
config, api_keys_db, api_list = [], [], []
|
|
|
269 |
unique_models = set()
|
270 |
|
271 |
for provider in config["providers"]:
|
272 |
+
model_dict = get_model_dict(provider)
|
273 |
+
for model in model_dict.keys():
|
274 |
if model not in unique_models:
|
275 |
unique_models.add(model)
|
276 |
model_info = {
|
|
|
302 |
# europe-west1
|
303 |
# europe-west4
|
304 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
|
306 |
+
c35s = ThreadSafeCircularList(["us-east5", "europe-west1"])
|
307 |
+
c3s = ThreadSafeCircularList(["us-east5", "us-central1", "asia-southeast1"])
|
308 |
+
c3o = ThreadSafeCircularList(["us-east5"])
|
309 |
+
c3h = ThreadSafeCircularList(["us-east5", "us-central1", "europe-west1", "europe-west4"])
|
310 |
+
gem = ThreadSafeCircularList(["us-central1", "us-east4", "us-west1", "us-west4", "europe-west1", "europe-west2"])
|
311 |
|
312 |
class BaseAPI:
|
313 |
def __init__(
|