yym68686 commited on
Commit
a04542e
·
1 Parent(s): 6d80128

✨ Feature: Add support for frontend page operation configuration files.

Browse files
Files changed (2) hide show
  1. main.py +496 -18
  2. test/xue/test_home.py +8 -12
main.py CHANGED
@@ -3,12 +3,12 @@ from log_config import logger
3
  import re
4
  import httpx
5
  import secrets
6
- import time as time_module
7
  from contextlib import asynccontextmanager
8
  from starlette.middleware.base import BaseHTTPMiddleware
9
 
10
  from fastapi.middleware.cors import CORSMiddleware
11
- from fastapi import FastAPI, HTTPException, Depends, Request
12
  from fastapi.responses import JSONResponse
13
  from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
14
  from starlette.responses import StreamingResponse as StarletteStreamingResponse
@@ -77,6 +77,13 @@ def _get_default_sql(default):
77
 
78
  @asynccontextmanager
79
  async def lifespan(app: FastAPI):
 
 
 
 
 
 
 
80
  # 启动时的代码
81
  await create_tables()
82
 
@@ -95,6 +102,16 @@ async def lifespan(app: FastAPI):
95
  )
96
  # app.state.client = httpx.AsyncClient(timeout=timeout)
97
  app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)
 
 
 
 
 
 
 
 
 
 
98
  yield
99
  # 关闭时的代码
100
  await app.state.client.aclose()
@@ -113,7 +130,6 @@ async def http_exception_handler(request: Request, exc: HTTPException):
113
  import uuid
114
  import json
115
  import asyncio
116
- from time import time
117
  import contextvars
118
  request_info = contextvars.ContextVar('request_info', default={})
119
 
@@ -391,18 +407,19 @@ class StatsMiddleware(BaseHTTPMiddleware):
391
  try:
392
  response = await call_next(request)
393
 
394
- if isinstance(response, (FastAPIStreamingResponse, StarletteStreamingResponse)) or type(response).__name__ == '_StreamingResponse':
395
- response = LoggingStreamingResponse(
396
- content=response.body_iterator,
397
- status_code=response.status_code,
398
- media_type=response.media_type,
399
- headers=response.headers,
400
- current_info=current_info,
401
- )
402
- elif hasattr(response, 'json'):
403
- logger.info(f"Response: {await response.json()}")
404
- else:
405
- logger.info(f"Response: type={type(response).__name__}, status_code={response.status_code}, headers={response.headers}")
 
406
 
407
  return response
408
  finally:
@@ -793,7 +810,7 @@ class InMemoryRateLimiter:
793
  self.requests = defaultdict(list)
794
 
795
  async def is_rate_limited(self, key: str, limit: int, period: int) -> bool:
796
- now = time_module.time()
797
  self.requests[key] = [req for req in self.requests[key] if req > now - period]
798
  if len(self.requests[key]) >= limit:
799
  return True
@@ -910,7 +927,7 @@ async def audio_transcriptions(
910
  traceback.print_exc()
911
  raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}")
912
 
913
- @app.get("/generate-api-key", dependencies=[Depends(rate_limit_dependency)])
914
  def generate_api_key():
915
  # Define the character set (only alphanumeric)
916
  chars = string.ascii_letters + string.digits
@@ -924,7 +941,7 @@ from datetime import datetime, timedelta, timezone
924
  from sqlalchemy import func, desc, case
925
  from fastapi import Query
926
 
927
- @app.get("/stats", dependencies=[Depends(rate_limit_dependency)])
928
  async def get_stats(
929
  request: Request,
930
  token: str = Depends(verify_admin_api_key),
@@ -1026,6 +1043,467 @@ async def get_stats(
1026
 
1027
  return JSONResponse(content=stats)
1028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1029
  # async def on_fetch(request, env):
1030
  # import asgi
1031
  # return await asgi.fetch(app, request, env)
 
3
  import re
4
  import httpx
5
  import secrets
6
+ from time import time
7
  from contextlib import asynccontextmanager
8
  from starlette.middleware.base import BaseHTTPMiddleware
9
 
10
  from fastapi.middleware.cors import CORSMiddleware
11
+ from fastapi import FastAPI, HTTPException, Depends, Request, APIRouter
12
  from fastapi.responses import JSONResponse
13
  from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
14
  from starlette.responses import StreamingResponse as StarletteStreamingResponse
 
77
 
78
  @asynccontextmanager
79
  async def lifespan(app: FastAPI):
80
+ # print("Main app routes:")
81
+ # for route in app.routes:
82
+ # print(f"Route: {route.path}, methods: {route.methods}")
83
+
84
+ # print("\nFrontend router routes:")
85
+ # for route in frontend_router.routes:
86
+ # print(f"Route: {route.path}, methods: {route.methods}")
87
  # 启动时的代码
88
  await create_tables()
89
 
 
102
  )
103
  # app.state.client = httpx.AsyncClient(timeout=timeout)
104
  app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)
105
+
106
+ for item in app.state.api_keys_db:
107
+ if item.get("role") == "admin":
108
+ app.state.admin_api_key = item.get("api")
109
+ if not hasattr(app.state, "admin_api_key"):
110
+ if len(app.state.api_keys_db) >= 1:
111
+ app.state.admin_api_key = app.state.api_keys_db[0].get("api")
112
+ else:
113
+ raise Exception("No admin API key found")
114
+
115
  yield
116
  # 关闭时的代码
117
  await app.state.client.aclose()
 
130
  import uuid
131
  import json
132
  import asyncio
 
133
  import contextvars
134
  request_info = contextvars.ContextVar('request_info', default={})
135
 
 
407
  try:
408
  response = await call_next(request)
409
 
410
+ if request.url.path.startswith("/v1"):
411
+ if isinstance(response, (FastAPIStreamingResponse, StarletteStreamingResponse)) or type(response).__name__ == '_StreamingResponse':
412
+ response = LoggingStreamingResponse(
413
+ content=response.body_iterator,
414
+ status_code=response.status_code,
415
+ media_type=response.media_type,
416
+ headers=response.headers,
417
+ current_info=current_info,
418
+ )
419
+ elif hasattr(response, 'json'):
420
+ logger.info(f"Response: {await response.json()}")
421
+ else:
422
+ logger.info(f"Response: type={type(response).__name__}, status_code={response.status_code}, headers={response.headers}")
423
 
424
  return response
425
  finally:
 
810
  self.requests = defaultdict(list)
811
 
812
  async def is_rate_limited(self, key: str, limit: int, period: int) -> bool:
813
+ now = time()
814
  self.requests[key] = [req for req in self.requests[key] if req > now - period]
815
  if len(self.requests[key]) >= limit:
816
  return True
 
927
  traceback.print_exc()
928
  raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}")
929
 
930
+ @app.get("/v1/generate-api-key", dependencies=[Depends(rate_limit_dependency)])
931
  def generate_api_key():
932
  # Define the character set (only alphanumeric)
933
  chars = string.ascii_letters + string.digits
 
941
  from sqlalchemy import func, desc, case
942
  from fastapi import Query
943
 
944
+ @app.get("/v1/stats", dependencies=[Depends(rate_limit_dependency)])
945
  async def get_stats(
946
  request: Request,
947
  token: str = Depends(verify_admin_api_key),
 
1043
 
1044
  return JSONResponse(content=stats)
1045
 
1046
+
1047
+
1048
+ from fastapi import FastAPI, Request
1049
+ from fastapi import Form as FastapiForm, HTTPException, Depends
1050
+ from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
1051
+ from fastapi.security import APIKeyHeader
1052
+ from typing import Optional, List
1053
+
1054
+ from xue import HTML, Head, Body, Div, xue_initialize, Script
1055
+ from xue.components.menubar import (
1056
+ Menubar, MenubarMenu, MenubarTrigger, MenubarContent,
1057
+ MenubarItem, MenubarSeparator
1058
+ )
1059
+ from xue.components import input
1060
+ from xue.components import dropdown, sheet, form, button, checkbox
1061
+ from xue.components.model_config_row import model_config_row
1062
+ # import sys
1063
+ # import os
1064
+ # sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
1065
+ from components.provider_table import data_table
1066
+
1067
+ from ruamel.yaml import YAML
1068
+ yaml = YAML()
1069
+ yaml.preserve_quotes = True
1070
+ yaml.indent(mapping=2, sequence=4, offset=2)
1071
+
1072
+
1073
+ frontend_router = APIRouter()
1074
+
1075
+ API_KEY_NAME = "X-API-Key"
1076
+ api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
1077
+ async def get_api_key(request: Request, x_api_key: Optional[str] = Depends(api_key_header)):
1078
+ if not x_api_key:
1079
+ x_api_key = request.cookies.get("x_api_key") or request.query_params.get("x_api_key")
1080
+ # print(f"Cookie x_api_key: {request.cookies.get('x_api_key')}") # 添加此行
1081
+ # print(f"Query param x_api_key: {request.query_params.get('x_api_key')}") # 添加此行
1082
+ # print(f"Header x_api_key: {x_api_key}") # 添加此行
1083
+ # logger.info(f"x_api_key: {x_api_key} {x_api_key == 'your_admin_api_key'}")
1084
+
1085
+ if x_api_key == app.state.admin_api_key: # 替换为实际的管理员API密钥
1086
+ return x_api_key
1087
+ else:
1088
+ return None
1089
+
1090
+ async def frontend_rate_limit_dependency(request: Request, x_api_key: str = Depends(get_api_key)):
1091
+ token = x_api_key if x_api_key else None
1092
+ limit, period = 100, 60
1093
+
1094
+ # 使用 IP 地址和 token(如果有)作为限制键
1095
+ client_ip = request.client.host
1096
+ rate_limit_key = f"{client_ip}:{token}" if token else client_ip
1097
+
1098
+ if await rate_limiter.is_rate_limited(rate_limit_key, limit, period):
1099
+ raise HTTPException(status_code=429, detail="Too many requests")
1100
+
1101
+ # def get_backend_router_api_list():
1102
+ # api_list = []
1103
+ # for route in frontend_router.routes:
1104
+ # api_list.append({
1105
+ # "path": f"/api{route.path}", # 加上前缀
1106
+ # "method": route.methods,
1107
+ # "name": route.name,
1108
+ # "summary": route.summary
1109
+ # })
1110
+ # return api_list
1111
+
1112
+ # @app.get("/backend-router-api-list")
1113
+ # async def backend_router_api_list():
1114
+ # return get_backend_router_api_list()
1115
+
1116
+ xue_initialize(tailwind=True)
1117
+
1118
+ API_YAML_PATH = "./api.yaml"
1119
+
1120
+ data_table_columns = [
1121
+ # {"label": "Status", "value": "status", "sortable": True},
1122
+ {"label": "Provider", "value": "provider", "sortable": True},
1123
+ {"label": "Base url", "value": "base_url", "sortable": True},
1124
+ # {"label": "Engine", "value": "engine", "sortable": True},
1125
+ {"label": "Tools", "value": "tools", "sortable": True},
1126
+ ]
1127
+
1128
+ @frontend_router.get("/login", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
1129
+ async def login_page():
1130
+ return HTML(
1131
+ Head(title="登录"),
1132
+ Body(
1133
+ Div(
1134
+ form.Form(
1135
+ form.FormField("API Key", "x_api_key", type="password", placeholder="输入API密钥", required=True),
1136
+ Div(id="error-message", class_="text-red-500 mt-2"),
1137
+ Div(
1138
+ button.button("提交", variant="primary", type="submit"),
1139
+ class_="flex justify-end mt-4"
1140
+ ),
1141
+ hx_post="/verify-api-key",
1142
+ hx_target="#error-message",
1143
+ hx_swap="innerHTML",
1144
+ class_="space-y-4"
1145
+ ),
1146
+ class_="container mx-auto p-4 max-w-md"
1147
+ )
1148
+ )
1149
+ ).render()
1150
+
1151
+
1152
+ @frontend_router.post("/verify-api-key", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
1153
+ async def verify_api_key(x_api_key: str = FastapiForm(...)):
1154
+ if x_api_key == app.state.admin_api_key: # 替换为实际的管理员API密钥
1155
+ response = JSONResponse(content={"success": True})
1156
+ response.headers["HX-Redirect"] = "/" # 添加这一行
1157
+ response.set_cookie(
1158
+ key="x_api_key",
1159
+ value=x_api_key,
1160
+ httponly=True,
1161
+ max_age=1800, # 30分钟
1162
+ secure=False, # 在开发环境中设置为False,生产环境中使用HTTPS时设置为True
1163
+ samesite="lax" # 改为"lax"以允许重定向时携带cookie
1164
+ )
1165
+ return response
1166
+ else:
1167
+ return Div("无效的API密钥", class_="text-red-500").render()
1168
+
1169
+ @frontend_router.get("/", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
1170
+ async def root(x_api_key: str = Depends(get_api_key)):
1171
+ if not x_api_key:
1172
+ return RedirectResponse(url="/login", status_code=303)
1173
+
1174
+ result = HTML(
1175
+ Head(
1176
+ Script("""
1177
+ document.addEventListener('DOMContentLoaded', function() {
1178
+ const filterInput = document.getElementById('users-table-filter');
1179
+ filterInput.addEventListener('input', function() {
1180
+ const filterValue = this.value;
1181
+ htmx.ajax('GET', `/filter-table?filter=${filterValue}`, '#users-table');
1182
+ });
1183
+ });
1184
+ """),
1185
+ title="Menubar Example"
1186
+ ),
1187
+ Body(
1188
+ Div(
1189
+ Menubar(
1190
+ MenubarMenu(
1191
+ MenubarTrigger("File", "file-menu"),
1192
+ MenubarContent(
1193
+ MenubarItem("New Tab", shortcut="⌘T"),
1194
+ MenubarItem("New Window", shortcut="⌘N"),
1195
+ MenubarItem("New Incognito Window", disabled=True),
1196
+ MenubarSeparator(),
1197
+ MenubarItem("Print...", shortcut="⌘P"),
1198
+ ),
1199
+ id="file-menu"
1200
+ ),
1201
+ MenubarMenu(
1202
+ MenubarTrigger("Edit", "edit-menu"),
1203
+ MenubarContent(
1204
+ MenubarItem("Undo", shortcut="⌘Z"),
1205
+ MenubarItem("Redo", shortcut="⇧⌘Z"),
1206
+ MenubarSeparator(),
1207
+ MenubarItem("Cut"),
1208
+ MenubarItem("Copy"),
1209
+ MenubarItem("Paste"),
1210
+ ),
1211
+ id="edit-menu"
1212
+ ),
1213
+ MenubarMenu(
1214
+ MenubarTrigger("View", "view-menu"),
1215
+ MenubarContent(
1216
+ MenubarItem("Always Show Bookmarks Bar"),
1217
+ MenubarItem("Always Show Full URLs"),
1218
+ MenubarSeparator(),
1219
+ MenubarItem("Reload", shortcut="⌘R"),
1220
+ MenubarItem("Force Reload", shortcut="⇧⌘R", disabled=True),
1221
+ MenubarSeparator(),
1222
+ MenubarItem("Toggle Fullscreen"),
1223
+ MenubarItem("Hide Sidebar"),
1224
+ ),
1225
+ id="view-menu"
1226
+ ),
1227
+ ),
1228
+ class_="p-4"
1229
+ ),
1230
+ Div(
1231
+ data_table(data_table_columns, app.state.config["providers"], "users-table"),
1232
+ class_="p-4"
1233
+ ),
1234
+ Div(id="sheet-container"), # 这里是 sheet 将被加载的地方
1235
+ class_="container mx-auto",
1236
+ id="body"
1237
+ )
1238
+ ).render()
1239
+ # print(result)
1240
+ return result
1241
+
1242
+ @frontend_router.get("/dropdown-menu/{menu_id}/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
1243
+ async def get_columns_menu(menu_id: str, row_id: str):
1244
+ columns = [
1245
+ {
1246
+ "label": "Edit",
1247
+ "value": "edit",
1248
+ "hx-get": f"/edit-sheet/{row_id}",
1249
+ "hx-target": "#sheet-container",
1250
+ "hx-swap": "innerHTML"
1251
+ },
1252
+ {
1253
+ "label": "Duplicate",
1254
+ "value": "duplicate",
1255
+ "hx-post": f"/duplicate/{row_id}",
1256
+ "hx-target": "body",
1257
+ "hx-swap": "outerHTML"
1258
+ },
1259
+ {
1260
+ "label": "Delete",
1261
+ "value": "delete",
1262
+ "hx-delete": f"/delete/{row_id}",
1263
+ "hx-target": "body",
1264
+ "hx-swap": "outerHTML",
1265
+ "hx-confirm": "确定要删除这个配置吗?"
1266
+ },
1267
+ ]
1268
+ result = dropdown.dropdown_menu_content(menu_id, columns).render()
1269
+ print(result)
1270
+ return result
1271
+
1272
+ @frontend_router.get("/dropdown-menu/{menu_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
1273
+ async def get_columns_menu(menu_id: str):
1274
+ result = dropdown.dropdown_menu_content(menu_id, data_table_columns).render()
1275
+ print(result)
1276
+ return result
1277
+
1278
+ @frontend_router.get("/filter-table", response_class=HTMLResponse)
1279
+ async def filter_table(filter: str = ""):
1280
+ filtered_data = [
1281
+ provider for provider in app.state.config["providers"]
1282
+ if filter.lower() in str(provider["provider"]).lower() or
1283
+ filter.lower() in str(provider["base_url"]).lower() or
1284
+ filter.lower() in str(provider["tools"]).lower()
1285
+ ]
1286
+ return data_table(data_table_columns, filtered_data, "users-table", with_filter=False).render()
1287
+
1288
+ @frontend_router.post("/add-model", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
1289
+ async def add_model():
1290
+ new_model_id = f"model{hash(str(time()))}" # 生成一个唯一的ID
1291
+ new_model = model_config_row(new_model_id).render()
1292
+ return new_model
1293
+
1294
+ @frontend_router.get("/edit-sheet/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
1295
+ async def get_edit_sheet(row_id: str, x_api_key: str = Depends(get_api_key)):
1296
+ row_data = get_row_data(row_id)
1297
+ print("row_data", row_data)
1298
+
1299
+ model_list = []
1300
+ for index, model in enumerate(row_data["model"]):
1301
+ if isinstance(model, str):
1302
+ model_list.append(model_config_row(f"model{index}", model, "", True))
1303
+ if isinstance(model, dict):
1304
+ # print("model", model, list(model.items())[0])
1305
+ key, value = list(model.items())[0]
1306
+ model_list.append(model_config_row(f"model{index}", key, value, True))
1307
+
1308
+ sheet_id = "edit-sheet"
1309
+ edit_sheet_content = sheet.SheetContent(
1310
+ sheet.SheetHeader(
1311
+ sheet.SheetTitle("Edit Item"),
1312
+ sheet.SheetDescription("Make changes to your item here.")
1313
+ ),
1314
+ sheet.SheetBody(
1315
+ Div(
1316
+ form.Form(
1317
+ form.FormField("Provider", "provider", value=row_data["provider"], placeholder="Enter provider name", required=True),
1318
+ form.FormField("Base URL", "base_url", value=row_data["base_url"], placeholder="Enter base URL", required=True),
1319
+ form.FormField("API Key", "api_key", value=row_data["api"], type="text", placeholder="Enter API key"),
1320
+ Div(
1321
+ Div("Models", class_="text-lg font-semibold mb-2"),
1322
+ Div(
1323
+ *model_list,
1324
+ id="models-container"
1325
+ ),
1326
+ button.button(
1327
+ "Add Model",
1328
+ class_="mt-2",
1329
+ hx_post="/add-model",
1330
+ hx_target="#models-container",
1331
+ hx_swap="beforeend"
1332
+ ),
1333
+ class_="mb-4"
1334
+ ),
1335
+ Div(
1336
+ checkbox.checkbox("tools", "Enable Tools", checked=row_data["tools"], name="tools"),
1337
+ class_="mb-4"
1338
+ ),
1339
+ form.FormField("Notes", "notes", value=row_data.get("notes", ""), placeholder="Enter any additional notes"),
1340
+ Div(
1341
+ button.button("Submit", variant="primary", type="submit"),
1342
+ button.button("Cancel", variant="outline", type="button", class_="ml-2", onclick=f"toggleSheet('{sheet_id}')"),
1343
+ class_="flex justify-end mt-4"
1344
+ ),
1345
+ hx_post=f"/submit/{row_id}",
1346
+ hx_swap="outerHTML",
1347
+ hx_target="body",
1348
+ class_="space-y-4"
1349
+ ),
1350
+ class_="container mx-auto p-4 max-w-2xl"
1351
+ )
1352
+ )
1353
+ )
1354
+
1355
+ result = sheet.Sheet(
1356
+ sheet_id,
1357
+ Div(),
1358
+ edit_sheet_content,
1359
+ width="80%",
1360
+ max_width="800px"
1361
+ ).render()
1362
+ return result
1363
+
1364
+ @frontend_router.get("/add-provider-sheet", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
1365
+ async def get_add_provider_sheet():
1366
+ edit_sheet_content = sheet.SheetContent(
1367
+ sheet.SheetHeader(
1368
+ sheet.SheetTitle("Add New Provider"),
1369
+ sheet.SheetDescription("Enter details for the new provider.")
1370
+ ),
1371
+ sheet.SheetBody(
1372
+ Div(
1373
+ form.Form(
1374
+ form.FormField("Provider", "provider", placeholder="Enter provider name", required=True),
1375
+ form.FormField("Base URL", "base_url", placeholder="Enter base URL", required=True),
1376
+ form.FormField("API Key", "api_key", type="text", placeholder="Enter API key"),
1377
+ Div(
1378
+ Div("Models", class_="text-lg font-semibold mb-2"),
1379
+ Div(id="models-container"),
1380
+ button.button(
1381
+ "Add Model",
1382
+ class_="mt-2",
1383
+ hx_post="/add-model",
1384
+ hx_target="#models-container",
1385
+ hx_swap="beforeend"
1386
+ ),
1387
+ class_="mb-4"
1388
+ ),
1389
+ Div(
1390
+ checkbox.checkbox("tools", "Enable Tools", name="tools"),
1391
+ class_="mb-4"
1392
+ ),
1393
+ form.FormField("Notes", "notes", placeholder="Enter any additional notes"),
1394
+ Div(
1395
+ button.button("Submit", variant="primary", type="submit"),
1396
+ button.button("Cancel", variant="outline", class_="ml-2"),
1397
+ class_="flex justify-end mt-4"
1398
+ ),
1399
+ hx_post="/submit/new",
1400
+ hx_swap="outerHTML",
1401
+ hx_target="body",
1402
+ class_="space-y-4"
1403
+ ),
1404
+ class_="container mx-auto p-4 max-w-2xl"
1405
+ )
1406
+ )
1407
+ )
1408
+
1409
+ result = sheet.Sheet(
1410
+ "add-provider-sheet",
1411
+ Div(),
1412
+ edit_sheet_content,
1413
+ width="80%",
1414
+ max_width="800px"
1415
+ ).render()
1416
+ return result
1417
+
1418
+ def get_row_data(row_id):
1419
+ index = int(row_id)
1420
+ # print(app.state.config["providers"])
1421
+ return app.state.config["providers"][index]
1422
+
1423
+ def update_row_data(row_id, updated_data):
1424
+ print(row_id, updated_data)
1425
+ index = int(row_id)
1426
+ app.state.config["providers"][index] = updated_data
1427
+ save_api_yaml()
1428
+
1429
+ def save_api_yaml():
1430
+ with open(API_YAML_PATH, "w", encoding="utf-8") as f:
1431
+ yaml.dump(app.state.config, f)
1432
+
1433
+ @frontend_router.post("/submit/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
1434
+ async def submit_form(
1435
+ row_id: str,
1436
+ request: Request,
1437
+ provider: str = FastapiForm(...),
1438
+ base_url: str = FastapiForm(...),
1439
+ api_key: Optional[str] = FastapiForm(None),
1440
+ tools: Optional[str] = FastapiForm(None),
1441
+ notes: Optional[str] = FastapiForm(None),
1442
+ x_api_key: str = Depends(get_api_key)
1443
+ ):
1444
+ form_data = await request.form()
1445
+
1446
+ # 收集模型数据
1447
+ models = []
1448
+ for key, value in form_data.items():
1449
+ if key.startswith("model_name_"):
1450
+ model_id = key.split("_")[-1]
1451
+ enabled = form_data.get(f"model_enabled_{model_id}") == "on"
1452
+ rename = form_data.get(f"model_rename_{model_id}")
1453
+ if value:
1454
+ if rename:
1455
+ models.append({value: rename})
1456
+ else:
1457
+ models.append(value)
1458
+
1459
+ updated_data = {
1460
+ "provider": provider,
1461
+ "base_url": base_url,
1462
+ "api": api_key,
1463
+ "model": models,
1464
+ "tools": tools == "on",
1465
+ "notes": notes,
1466
+ }
1467
+
1468
+ print("updated_data", updated_data)
1469
+
1470
+ if row_id == "new":
1471
+ # 添加新提供者
1472
+ app.state.config["providers"].append(updated_data)
1473
+ else:
1474
+ # 更新现有提供者
1475
+ update_row_data(row_id, updated_data)
1476
+
1477
+ # 保存更新后的配置
1478
+ save_api_yaml()
1479
+
1480
+ return await root()
1481
+
1482
+ @frontend_router.post("/duplicate/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
1483
+ async def duplicate_row(row_id: str):
1484
+ index = int(row_id)
1485
+ original_data = app.state.config["providers"][index]
1486
+ new_data = original_data.copy()
1487
+ new_data["provider"] += "-copy"
1488
+ app.state.config["providers"].insert(index + 1, new_data)
1489
+
1490
+ # 保存更新后的配置
1491
+ save_api_yaml()
1492
+
1493
+ return await root()
1494
+
1495
+ @frontend_router.delete("/delete/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
1496
+ async def delete_row(row_id: str):
1497
+ index = int(row_id)
1498
+ del app.state.config["providers"][index]
1499
+
1500
+ # 保存更新后的配置
1501
+ save_api_yaml()
1502
+
1503
+ return await root()
1504
+
1505
+ app.include_router(frontend_router, tags=["frontend"])
1506
+
1507
  # async def on_fetch(request, env):
1508
  # import asgi
1509
  # return await asgi.fetch(app, request, env)
test/xue/test_home.py CHANGED
@@ -33,6 +33,7 @@ import logging
33
  logging.basicConfig(level=logging.INFO)
34
  logger = logging.getLogger(__name__)
35
 
 
36
  class RequestBodyLoggerMiddleware(BaseHTTPMiddleware):
37
  async def dispatch(self, request: Request, call_next):
38
  if request.method == "POST" and request.url.path.startswith("/submit/"):
@@ -47,7 +48,6 @@ from utils import load_config
47
  from contextlib import asynccontextmanager
48
  @asynccontextmanager
49
  async def lifespan(app: FastAPI):
50
- # app.state.client = httpx.AsyncClient(timeout=timeout)
51
  app.state.config, app.state.api_keys_db, app.state.api_list = await load_config()
52
  for item in app.state.api_keys_db:
53
  if item.get("role") == "admin":
@@ -58,10 +58,6 @@ async def lifespan(app: FastAPI):
58
  else:
59
  raise Exception("No admin API key found")
60
 
61
- global data
62
- # providers_data = app.state.config["providers"]
63
-
64
- # print("data", data)
65
  yield
66
  # 关闭时的代码
67
  await app.state.client.aclose()
@@ -393,7 +389,10 @@ def update_row_data(row_id, updated_data):
393
  print(row_id, updated_data)
394
  index = int(row_id)
395
  app.state.config["providers"][index] = updated_data
396
- with open("./api1.yaml", "w", encoding="utf-8") as f:
 
 
 
397
  yaml.dump(app.state.config, f)
398
 
399
  @app.post("/submit/{row_id}", response_class=HTMLResponse)
@@ -441,8 +440,7 @@ async def submit_form(
441
  update_row_data(row_id, updated_data)
442
 
443
  # 保存更新后的配置
444
- with open("./api1.yaml", "w", encoding="utf-8") as f:
445
- yaml.dump(app.state.config, f)
446
 
447
  return await root()
448
 
@@ -455,8 +453,7 @@ async def duplicate_row(row_id: str):
455
  app.state.config["providers"].insert(index + 1, new_data)
456
 
457
  # 保存更新后的配置
458
- with open("./api1.yaml", "w", encoding="utf-8") as f:
459
- yaml.dump(app.state.config, f)
460
 
461
  return await root()
462
 
@@ -466,8 +463,7 @@ async def delete_row(row_id: str):
466
  del app.state.config["providers"][index]
467
 
468
  # 保存更新后的配置
469
- with open("./api1.yaml", "w", encoding="utf-8") as f:
470
- yaml.dump(app.state.config, f)
471
 
472
  return await root()
473
 
 
33
  logging.basicConfig(level=logging.INFO)
34
  logger = logging.getLogger(__name__)
35
 
36
+ API_YAML_PATH = "./api.yaml"
37
  class RequestBodyLoggerMiddleware(BaseHTTPMiddleware):
38
  async def dispatch(self, request: Request, call_next):
39
  if request.method == "POST" and request.url.path.startswith("/submit/"):
 
48
  from contextlib import asynccontextmanager
49
  @asynccontextmanager
50
  async def lifespan(app: FastAPI):
 
51
  app.state.config, app.state.api_keys_db, app.state.api_list = await load_config()
52
  for item in app.state.api_keys_db:
53
  if item.get("role") == "admin":
 
58
  else:
59
  raise Exception("No admin API key found")
60
 
 
 
 
 
61
  yield
62
  # 关闭时的代码
63
  await app.state.client.aclose()
 
389
  print(row_id, updated_data)
390
  index = int(row_id)
391
  app.state.config["providers"][index] = updated_data
392
+ save_api_yaml()
393
+
394
+ def save_api_yaml():
395
+ with open(API_YAML_PATH, "w", encoding="utf-8") as f:
396
  yaml.dump(app.state.config, f)
397
 
398
  @app.post("/submit/{row_id}", response_class=HTMLResponse)
 
440
  update_row_data(row_id, updated_data)
441
 
442
  # 保存更新后的配置
443
+ save_api_yaml()
 
444
 
445
  return await root()
446
 
 
453
  app.state.config["providers"].insert(index + 1, new_data)
454
 
455
  # 保存更新后的配置
456
+ save_api_yaml()
 
457
 
458
  return await root()
459
 
 
463
  del app.state.config["providers"][index]
464
 
465
  # 保存更新后的配置
466
+ save_api_yaml()
 
467
 
468
  return await root()
469