from log_config import logger |
import copy |
import httpx |
import secrets |
from time import time |
from contextlib import asynccontextmanager |
from starlette.middleware.base import BaseHTTPMiddleware |
from fastapi.middleware.cors import CORSMiddleware |
from fastapi import FastAPI, HTTPException, Depends, Request, APIRouter |
from fastapi.responses import JSONResponse |
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse |
from starlette.responses import StreamingResponse as StarletteStreamingResponse |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
from fastapi.exceptions import RequestValidationError |
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, TextToSpeechRequest, UnifiedRequest, EmbeddingRequest |
from request import get_payload |
from response import fetch_response, fetch_response_stream |
from utils import ( |
safe_get, |
load_config, |
save_api_yaml, |
get_model_dict, |
post_all_models, |
get_user_rate_limit, |
circular_list_encoder, |
error_handling_wrapper, |
rate_limiter, |
provider_api_circular_list, |
) |
from collections import defaultdict |
from typing import List, Dict, Union |
from urllib.parse import urlparse |
import os |
import string |
import json |
# Default upstream timeout in seconds; override with the TIMEOUT environment variable.
DEFAULT_TIMEOUT = int(os.getenv("TIMEOUT", 100))
# bool() of any non-empty string is True, so DEBUG=false or DEBUG=0 used to
# switch debug mode ON. The usual "off" spellings now disable it; any other
# non-empty value still enables it, as before.
is_debug = os.getenv("DEBUG", "").strip().lower() not in ("", "0", "false", "no", "off")
from sqlalchemy import inspect, text |
from sqlalchemy.sql import sqltypes |
DISABLE_DATABASE = os.getenv("DISABLE_DATABASE", "false").lower() == "true" |
IS_VERCEL = os.path.dirname(os.path.abspath(__file__)).startswith('/var/task') |
logger.info("IS_VERCEL: %s", IS_VERCEL) |
async def create_tables():
    """Create the statistics tables and add any columns they are missing.

    Does nothing when ``DISABLE_DATABASE`` is set. Otherwise all tables
    registered on ``Base`` are created, and every column declared on
    ``RequestStat`` / ``ChannelStat`` that the live table lacks is added with
    ``ALTER TABLE``.
    """
    # This used to be an unconditional ``return``, which made everything below
    # unreachable and left the DISABLE_DATABASE flag without any effect.
    if DISABLE_DATABASE:
        return

    def check_and_add_columns(connection):
        """Add model columns that are absent from the existing tables."""
        inspector = inspect(connection)
        for table in [RequestStat, ChannelStat]:
            table_name = table.__tablename__
            existing_columns = {col['name'] for col in inspector.get_columns(table_name)}
            for column_name, column in table.__table__.columns.items():
                if column_name in existing_columns:
                    continue
                col_type = _map_sa_type_to_sql_type(column.type)
                default = _get_default_sql(column.default)
                # Table and column names come from the model metadata, not from user input.
                connection.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {col_type}{default}"))

    async with db_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
        await conn.run_sync(check_and_add_columns)
def _map_sa_type_to_sql_type(sa_type):
    """Return the SQL type name to use for a SQLAlchemy column type instance.

    The instance's class hierarchy is searched from most to least specific, so
    a subclass of a mapped type resolves to its parent's SQL type instead of
    silently falling back.

    Args:
        sa_type: A SQLAlchemy type instance such as ``Column(...).type``.

    Returns:
        str: One of the mapped SQL type names, or ``"TEXT"`` when no class in
        the hierarchy is mapped.
    """
    type_map = {
        sqltypes.Integer: "INTEGER",
        sqltypes.String: "TEXT",
        sqltypes.Float: "REAL",
        sqltypes.Boolean: "BOOLEAN",
        sqltypes.DateTime: "DATETIME",
        sqltypes.Text: "TEXT",
    }
    # An exact type() lookup misses subclasses of the mapped types.
    for klass in type(sa_type).__mro__:
        sql_type = type_map.get(klass)
        if sql_type is not None:
            return sql_type
    return "TEXT"
def _get_default_sql(default): |
if default is None: |
return "" |
if isinstance(default.arg, bool): |
return f" DEFAULT {str(default.arg).upper()}" |
if isinstance(default.arg, (int, float)): |
return f" DEFAULT {default.arg}" |
if isinstance(default.arg, str): |
return f" DEFAULT '{default.arg}'" |
return "" |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: prepare the database on startup, close clients on shutdown."""
    # Startup.
    await create_tables()
    yield
    # Shutdown: the client manager only exists once a request has created it.
    state = app.state
    if hasattr(state, 'client_manager'):
        await state.client_manager.close()
app = FastAPI(lifespan=lifespan, debug=is_debug) |
def generate_markdown_docs():
    """Render the application's OpenAPI schema as a Markdown document.

    Returns:
        str: A header with title, version and description, followed by one
        section per HTTP method and path giving its summary, description and
        parameters. Sections are separated by a horizontal rule.
    """
    schema = app.openapi()
    info = schema['info']
    parts = [
        f"# {info['title']}\n\n",
        f"Version: {info['version']}\n\n",
        f"{info.get('description', '')}\n\n",
        "## API Endpoints\n\n",
    ]

    for path, path_info in schema['paths'].items():
        for method, operation in path_info.items():
            parts.append(f"### {method.upper()} {path}\n\n")
            parts.append(f"{operation.get('summary', '')}\n\n")
            parts.append(f"{operation.get('description', '')}\n\n")

            if 'parameters' in operation:
                parts.append("Parameters:\n")
                parts.extend(
                    f"- {param['name']} ({param['in']}): {param.get('description', '')}\n"
                    for param in operation['parameters']
                )

            parts.append("\n---\n\n")

    return "".join(parts)
@app.get("/docs/markdown")
async def get_markdown_docs():
    """Serve the generated API documentation as a ``text/markdown`` response."""
    return Response(content=generate_markdown_docs(), media_type="text/markdown")
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Turn an ``HTTPException`` into a JSON body of the form ``{"message": detail}``.

    404 errors are additionally written to the error log.
    """
    status = exc.status_code
    if status == 404:
        logger.error(f"404 Error: {exc.detail}")
    body = {"message": exc.detail}
    return JSONResponse(status_code=status, content=body)
import uuid |
import asyncio |
import contextvars |
request_info = contextvars.ContextVar('request_info', default={}) |
async def parse_request_body(request: Request):
    """Return the decoded JSON body of a JSON POST request.

    Returns:
        The parsed body, or ``None`` when the request is not a POST, is not
        declared as ``application/json``, or its body is not valid JSON.
    """
    is_json_post = request.method == "POST" and "application/json" in request.headers.get("content-type", "")
    if not is_json_post:
        return None
    try:
        return await request.json()
    except json.JSONDecodeError:
        return None
class ChannelManager:
    """Tracks provider/model pairs that are temporarily excluded.

    A pair stays excluded for ``cooldown_period`` seconds after
    ``exclude_model`` was last called for it.
    """

    def __init__(self, cooldown_period=300):
        # Maps "provider/model" to the time at which the pair was excluded.
        self._excluded_models = defaultdict(lambda: None)
        self.cooldown_period = cooldown_period

    async def exclude_model(self, provider: str, model: str):
        """Mark ``provider``/``model`` as excluded starting now."""
        model_key = f"{provider}/{model}"
        self._excluded_models[model_key] = datetime.now()

    async def is_model_excluded(self, provider: str, model: str) -> bool:
        """Return True while ``provider``/``model`` is inside its cooldown window.

        An expired exclusion is removed as a side effect.
        """
        model_key = f"{provider}/{model}"
        # Use .get() instead of indexing: indexing a defaultdict stores a None
        # entry for every key that is merely looked up, so the mapping would
        # grow with each provider/model pair ever queried.
        excluded_time = self._excluded_models.get(model_key)
        if not excluded_time:
            return False

        if datetime.now() - excluded_time > timedelta(seconds=self.cooldown_period):
            self._excluded_models.pop(model_key, None)
            return False
        return True

    async def get_available_providers(self, providers: list) -> list:
        """Filter out the available providers, dropping only those whose model is excluded.

        Each provider's target model is the first value of the first mapping
        in its ``"model"`` list.
        """
        available_providers = []
        for provider in providers:
            provider_name = provider['provider']
            model_dict = provider['model'][0]
            target_model = list(model_dict.values())[0]
            if not await self.is_model_excluded(provider_name, target_model):
                available_providers.append(provider)
        return available_providers
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession |
from sqlalchemy.orm import declarative_base, sessionmaker |
from sqlalchemy import Column, Integer, String, Float, DateTime, select, Boolean, Text |
from sqlalchemy.sql import func |
Base = declarative_base() |
class RequestStat(Base):
    """ORM model for the ``request_stats`` table: one row per handled request."""

    __tablename__ = 'request_stats'

    id = Column(Integer, primary_key=True)

    # Request identity and origin.
    request_id = Column(String)
    endpoint = Column(String)
    client_ip = Column(String)

    # Timings.
    process_time = Column(Float)
    first_response_time = Column(Float)

    # Routing.
    provider = Column(String)
    model = Column(String)
    api_key = Column(String)

    # Moderation outcome and the text that was recorded for the request.
    is_flagged = Column(Boolean, default=False)
    text = Column(Text)

    # Token accounting.
    prompt_tokens = Column(Integer, default=0)
    completion_tokens = Column(Integer, default=0)
    total_tokens = Column(Integer, default=0)

    # Filled in by the database when the row is inserted.
    timestamp = Column(DateTime(timezone=True), server_default=func.now())
class ChannelStat(Base):
    """ORM model for the ``channel_stats`` table: one row per provider attempt."""

    __tablename__ = 'channel_stats'

    id = Column(Integer, primary_key=True)
    request_id = Column(String)
    provider = Column(String)
    model = Column(String)
    api_key = Column(String)
    # Whether the attempt against this provider succeeded.
    success = Column(Boolean, default=False)
    # Filled in by the database when the row is inserted.
    timestamp = Column(DateTime(timezone=True), server_default=func.now())
# Location of the SQLite statistics database; override with the DB_PATH environment variable.
db_path = os.getenv('DB_PATH', '/data/stats.db')
data_dir = os.path.dirname(db_path)
# A bare file name such as "stats.db" has an empty dirname, and
# os.makedirs('') raises FileNotFoundError, so only create a real directory.
if data_dir:
    os.makedirs(data_dir, exist_ok=True)
db_engine = create_async_engine('sqlite+aiosqlite:///' + db_path, echo=is_debug)
async_session = sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
from starlette.types import Scope, Receive, Send |
from starlette.responses import Response |
from asyncio import Semaphore |
db_semaphore = Semaphore(1) |
async def update_stats(current_info):
    """Persist one request's statistics as a ``RequestStat`` row.

    Does nothing when ``DISABLE_DATABASE`` is set. Only the keys of
    ``current_info`` that are ``RequestStat`` columns are stored. Failures are
    logged and never propagated to the caller.

    Args:
        current_info: Mapping holding the per-request bookkeeping values.
    """
    # This used to be an unconditional ``return``, which made the body below
    # unreachable and left the DISABLE_DATABASE flag without any effect.
    if DISABLE_DATABASE:
        return

    try:
        # Writes are serialized through the module-level semaphore.
        async with db_semaphore:
            async with async_session() as session:
                async with session.begin():
                    try:
                        columns = {column.key for column in RequestStat.__table__.columns}
                        filtered_info = {k: v for k, v in current_info.items() if k in columns}
                        session.add(RequestStat(**filtered_info))
                        await session.commit()
                    except Exception as e:
                        await session.rollback()
                        logger.error(f"Error updating stats: {str(e)}")
                        if is_debug:
                            import traceback
                            traceback.print_exc()
    except Exception as e:
        logger.error(f"Error acquiring database lock: {str(e)}")
        if is_debug:
            import traceback
            traceback.print_exc()
async def update_channel_stats(request_id, provider, model, api_key, success):
    """Persist the outcome of one provider attempt as a ``ChannelStat`` row.

    Does nothing when ``DISABLE_DATABASE`` is set. Failures are logged and
    never propagated to the caller.

    Args:
        request_id: Identifier of the request the attempt belongs to.
        provider: Name of the provider that was tried.
        model: Model name that was requested.
        api_key: API key the request was made with.
        success: Whether the attempt succeeded.
    """
    # This used to be an unconditional ``return``, which made the body below
    # unreachable and left the DISABLE_DATABASE flag without any effect.
    if DISABLE_DATABASE:
        return

    try:
        # Writes are serialized through the module-level semaphore.
        async with db_semaphore:
            async with async_session() as session:
                async with session.begin():
                    try:
                        channel_stat = ChannelStat(
                            request_id=request_id,
                            provider=provider,
                            model=model,
                            api_key=api_key,
                            success=success,
                        )
                        session.add(channel_stat)
                        await session.commit()
                    except Exception as e:
                        await session.rollback()
                        logger.error(f"Error updating channel stats: {str(e)}")
                        if is_debug:
                            import traceback
                            traceback.print_exc()
    except Exception as e:
        logger.error(f"Error acquiring database lock: {str(e)}")
        if is_debug:
            import traceback
            traceback.print_exc()
class LoggingStreamingResponse(Response):
    """Streaming response that forwards chunks while recording usage statistics.

    Each chunk of the wrapped iterator is inspected for token usage, which is
    written into ``current_info``. When the stream ends, the total processing
    time is stored there as well and the record is handed to ``update_stats``.
    """

    def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
        """
        Args:
            content: Async iterator yielding ``str`` or ``bytes`` chunks.
            status_code: HTTP status code of the response.
            headers: Response headers.
            media_type: Response media type.
            current_info: Per-request bookkeeping mapping; must contain
                ``"start_time"``.
        """
        super().__init__(content=None, status_code=status_code, headers=headers, media_type=media_type)
        self.body_iterator = content
        self._closed = False
        self.current_info = current_info

        # The body is streamed, so drop any fixed length and send it chunked.
        if 'content-length' in self.headers:
            del self.headers['content-length']
        self.headers['transfer-encoding'] = 'chunked'

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point: stream the body, then finalize and store the statistics."""
        await send({
            'type': 'http.response.start',
            'status': self.status_code,
            'headers': self.raw_headers,
        })

        try:
            async for chunk in self._logging_iterator():
                await send({
                    'type': 'http.response.body',
                    'body': chunk,
                    'more_body': True,
                })
        finally:
            # Always terminate the body, even when the stream failed part-way.
            await send({
                'type': 'http.response.body',
                'body': b'',
                'more_body': False,
            })
            if hasattr(self.body_iterator, 'aclose') and not self._closed:
                await self.body_iterator.aclose()
                self._closed = True

            process_time = time() - self.current_info["start_time"]
            self.current_info["process_time"] = process_time
            await update_stats(self.current_info)

    async def _logging_iterator(self):
        """Yield the body as bytes, extracting token usage from JSON chunks."""
        # The endpoint is recorded as "<METHOD> <path>", so an equality test
        # against the bare path never matched and audio bytes were decoded as
        # UTF-8 text. Match on the path suffix instead.
        is_audio = str(self.current_info.get("endpoint") or "").endswith("/v1/audio/speech")
        try:
            async for chunk in self.body_iterator:
                if isinstance(chunk, str):
                    chunk = chunk.encode('utf-8')

                # Audio is binary: pass it through untouched.
                if is_audio:
                    yield chunk
                    continue

                line = chunk.decode('utf-8')
                if is_debug:
                    logger.info(f"{line.encode('utf-8').decode('unicode_escape')}")
                if line.startswith("data:"):
                    # str.lstrip("data: ") strips any of those characters, not
                    # the prefix, and can eat the start of the payload.
                    line = line.removeprefix("data:").lstrip(" ")
                if not line.startswith(("[DONE]", "OK")):
                    try:
                        resp: dict = json.loads(line)
                        # Prefer usage.prompt_tokens and fall back to
                        # message.usage.input_tokens, which was previously
                        # read and then immediately overwritten.
                        input_tokens = (
                            safe_get(resp, "usage", "prompt_tokens", default=0)
                            or safe_get(resp, "message", "usage", "input_tokens", default=0)
                        )
                        output_tokens = safe_get(resp, "usage", "completion_tokens", default=0)
                        total_tokens = input_tokens + output_tokens

                        self.current_info["prompt_tokens"] = input_tokens
                        self.current_info["completion_tokens"] = output_tokens
                        self.current_info["total_tokens"] = total_tokens
                    except Exception as e:
                        logger.error(f"Error parsing response: {str(e)}, line: {repr(line)}")
                        # NOTE(review): a chunk that cannot be parsed is not
                        # forwarded to the client - confirm this is intended.
                        continue
                yield chunk
        finally:
            logger.debug("_logging_iterator finished")

    async def close(self):
        """Close the wrapped iterator once; later calls are no-ops."""
        if not self._closed:
            self._closed = True
            if hasattr(self.body_iterator, 'aclose'):
                await self.body_iterator.aclose()
class StatsMiddleware(BaseHTTPMiddleware):
    """Middleware that identifies the caller, optionally moderates the request
    text, and collects per-request statistics."""

    def __init__(self, app):
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        """Process one request.

        Resolves the API key, initializes the per-request info record, runs the
        moderation check when it is enabled, and wraps streaming responses
        under ``/v1`` in ``LoggingStreamingResponse``.

        Returns:
            The downstream response, a 403 JSON response for a malformed
            ``Authorization`` header, or a 400 JSON response when the content
            is flagged by moderation.
        """
        start_time = time()

        # Stays False when a token is supplied but matches no configured key.
        enable_moderation = False

        config = app.state.config

        # Accept the key from "x-api-key" or from "Authorization: <scheme> <key>".
        if request.headers.get("x-api-key"):
            token = request.headers.get("x-api-key")
        elif request.headers.get("Authorization"):
            api_split_list = request.headers.get("Authorization").split(" ")
            if len(api_split_list) > 1:
                token = api_split_list[1]
            else:
                return JSONResponse(
                    status_code=403,
                    content={"error": "Invalid or missing API Key"}
                )
        else:
            token = None

        api_index = None
        if token:
            api_list = app.state.api_list
            try:
                api_index = api_list.index(token)
            except ValueError:
                # Not an exact match: accept a token that starts with a configured key.
                api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)

            if api_index is not None:
                enable_moderation = safe_get(config, 'api_keys', api_index, "preferences", "ENABLE_MODERATION", default=False)
        else:
            # No token at all: fall back to the global setting.
            enable_moderation = config.get('ENABLE_MODERATION', False)

        request_id = str(uuid.uuid4())

        request_info_data = {
            "request_id": request_id,
            "start_time": start_time,
            "endpoint": f"{request.method} {request.url.path}",
            # request.client is None when the server supplies no client address.
            "client_ip": request.client.host if request.client else None,
            "process_time": 0,
            "first_response_time": -1,
            "provider": None,
            "model": None,
            "success": False,
            "api_key": token,
            "is_flagged": False,
            "text": None,
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        }

        # Publish the record through the context variable for downstream code.
        current_request_info = request_info.set(request_info_data)
        current_info = request_info.get()

        parsed_body = await parse_request_body(request)
        if parsed_body:
            try:
                request_model = UnifiedRequest.model_validate(parsed_body).data
                model = request_model.model
                current_info["model"] = model

                # Pick the text to moderate according to the request type.
                moderated_content = None
                if request_model.request_type == "chat":
                    moderated_content = request_model.get_last_text_message()
                elif request_model.request_type == "image":
                    moderated_content = request_model.prompt
                elif request_model.request_type == "tts":
                    moderated_content = request_model.input
                elif request_model.request_type == "moderation":
                    pass
                elif request_model.request_type == "embedding":
                    if isinstance(request_model.input, list) and len(request_model.input) > 0 and isinstance(request_model.input[0], str):
                        moderated_content = "\n".join(request_model.input)
                    else:
                        moderated_content = request_model.input
                else:
                    logger.error(f"Unknown request type: {request_model.request_type}")

                if moderated_content:
                    current_info["text"] = moderated_content

                if enable_moderation and moderated_content:
                    moderation_response = await self.moderate_content(moderated_content, api_index)
                    is_flagged = moderation_response.get('results', [{}])[0].get('flagged', False)

                    if is_flagged:
                        logger.error("Content did not pass the moral check: %s", moderated_content)
                        process_time = time() - start_time
                        current_info["process_time"] = process_time
                        current_info["is_flagged"] = is_flagged
                        await update_stats(current_info)
                        return JSONResponse(
                            status_code=400,
                            content={"error": "Content did not pass the moral check, please modify and try again."}
                        )
            except RequestValidationError:
                logger.error(f"Invalid request body: {parsed_body}")
            except Exception as e:
                # Best effort: a failure here must not block the request itself.
                if is_debug:
                    import traceback
                    traceback.print_exc()
                logger.error(f"Error processing request or performing moral check: {str(e)}")

        try:
            response = await call_next(request)

            if request.url.path.startswith("/v1"):
                if isinstance(response, (FastAPIStreamingResponse, StarletteStreamingResponse)) or type(response).__name__ == '_StreamingResponse':
                    response = LoggingStreamingResponse(
                        content=response.body_iterator,
                        status_code=response.status_code,
                        media_type=response.media_type,
                        headers=response.headers,
                        current_info=current_info,
                    )
                elif hasattr(response, 'json'):
                    logger.info(f"Response: {await response.json()}")
                else:
                    logger.info(f"Response: type={type(response).__name__}, status_code={response.status_code}, headers={response.headers}")

            return response
        finally:
            request_info.reset(current_request_info)

    async def moderate_content(self, content, api_index):
        """Run ``content`` through the moderation endpoint.

        Returns:
            The decoded JSON body of the moderation response.
        """
        moderation_request = ModerationRequest(input=content)

        # Call the moderations handler directly.
        response = await moderations(moderation_request, api_index)

        # Collect the streamed body before decoding it.
        pieces = []
        async for chunk in response.body_iterator:
            pieces.append(chunk.encode('utf-8') if isinstance(chunk, str) else chunk)
        moderation_result = b"".join(pieces)

        return json.loads(moderation_result.decode('utf-8'))
app.add_middleware( |
CORSMiddleware, |
allow_origins=["*"], |
allow_credentials=True, |
allow_methods=["*"], |
allow_headers=["*"], |
) |
app.add_middleware(StatsMiddleware) |
class ClientManager:
    """Caches one ``httpx.AsyncClient`` per (host, timeout, proxy) combination."""

    def __init__(self, pool_size=100):
        self.pool_size = pool_size
        self.clients = {}
        # Defined up front so get_client() works even if init() was never awaited.
        self.default_config = {}

    async def init(self, default_config):
        """Store the keyword arguments shared by every client that gets created."""
        self.default_config = default_config

    @asynccontextmanager
    async def get_client(self, timeout_value, base_url, proxy=None):
        """Yield the cached client for the target, creating it on first use.

        Args:
            timeout_value: Read timeout in seconds (converted with ``int``).
            base_url: URL whose host identifies the client.
            proxy: Optional proxy URL; SOCKS5 proxies require ``httpx-socks``.

        Raises:
            ImportError: If a SOCKS5 proxy is requested without ``httpx-socks``.

        If the ``with`` body raises, the client is dropped from the cache and
        closed before the exception propagates.
        """
        timeout_value = int(timeout_value)

        host = urlparse(base_url).netloc
        client_key = f"{host}_{timeout_value}"
        if proxy:
            # socks5h:// and socks5:// share one client.
            proxy_normalized = proxy.replace('socks5h://', 'socks5://')
            client_key += f"_{proxy_normalized}"

        # On Vercel a fresh client is created on every call.
        if client_key not in self.clients or IS_VERCEL:
            timeout = httpx.Timeout(
                connect=15.0,
                read=timeout_value,
                write=30.0,
                # NOTE(review): the pool timeout is set to the pool size - confirm intended.
                pool=self.pool_size
            )
            limits = httpx.Limits(max_connections=self.pool_size)

            client_config = {
                **self.default_config,
                "timeout": timeout,
                "limits": limits
            }

            if proxy:
                # Treat "socks5h" like "socks5".
                scheme = urlparse(proxy).scheme.rstrip('h')
                if scheme == 'socks5':
                    try:
                        from httpx_socks import AsyncProxyTransport
                        proxy = proxy.replace('socks5h://', 'socks5://')
                        transport = AsyncProxyTransport.from_url(proxy)
                        client_config["transport"] = transport
                    except ImportError:
                        logger.error("httpx-socks package is required for SOCKS proxy support")
                        raise ImportError("Please install httpx-socks package for SOCKS proxy support: pip install httpx-socks")
                else:
                    client_config["proxies"] = {
                        "http://": proxy,
                        "https://": proxy
                    }

            self.clients[client_key] = httpx.AsyncClient(**client_config)

        try:
            yield self.clients[client_key]
        except Exception:
            # Discard a client whose request failed so the next call gets a fresh one.
            tmp_client = self.clients.pop(client_key, None)
            if tmp_client is not None:
                await tmp_client.aclose()
            raise

    async def close(self):
        """Close every cached client and empty the cache."""
        # Snapshot first: awaiting aclose() lets other tasks change the cache,
        # and mutating a dict while iterating over it raises RuntimeError.
        clients = list(self.clients.values())
        self.clients.clear()
        for client in clients:
            await client.aclose()
@app.middleware("http")
async def ensure_config(request: Request, call_next):
    """Lazily initialize the shared application state before handling a request.

    On first use this loads the configuration and admin API key, creates the
    HTTP client manager together with the timeout tables, and creates the
    channel manager. Returns a 500 JSON response when no API key is configured.
    """
    if app and not hasattr(app.state, 'config'):
        app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)

        # The last key with the admin role wins; otherwise the first key is used.
        for item in app.state.api_keys_db:
            if item.get("role") == "admin":
                app.state.admin_api_key = item.get("api")
        if not hasattr(app.state, "admin_api_key"):
            if len(app.state.api_keys_db) >= 1:
                app.state.admin_api_key = app.state.api_keys_db[0].get("api")
            else:
                from utils import yaml_error_message
                return JSONResponse(
                    status_code=500,
                    content={"error": yaml_error_message or "No admin API key found"}
                )

    if app and not hasattr(app.state, 'client_manager'):
        default_config = {
            "headers": {
                "User-Agent": "curl/7.68.0",
                "Accept": "*/*",
            },
            "http2": True,
            "verify": True,
            "follow_redirects": True
        }

        app.state.client_manager = ClientManager(pool_size=200)
        await app.state.client_manager.init(default_config)

        # Global per-model timeouts taken from preferences.model_timeout.
        app.state.timeouts = {}
        if app.state.config and 'preferences' in app.state.config:
            model_timeouts = app.state.config['preferences'].get('model_timeout', {})
            for model_name, timeout_value in model_timeouts.items():
                app.state.timeouts[model_name] = timeout_value
            if "default" not in model_timeouts:
                app.state.timeouts["default"] = DEFAULT_TIMEOUT

        # Per-provider overrides; the global table is stored under "global_time_out".
        app.state.provider_timeouts = defaultdict(lambda: defaultdict(lambda: DEFAULT_TIMEOUT))
        for provider in app.state.config["providers"]:
            provider_timeout_settings = safe_get(provider, "preferences", "model_timeout", default={})
            if provider_timeout_settings:
                for model_name, timeout_value in provider_timeout_settings.items():
                    app.state.provider_timeouts[provider['provider']][model_name] = timeout_value

        app.state.provider_timeouts["global_time_out"] = app.state.timeouts

    if app and not hasattr(app.state, "channel_manager"):
        # Previously the manager was only built in the branch where the
        # cooldown had not been computed, so it was either never created or
        # failed with a NameError. Resolve the cooldown first, then always
        # create the manager.
        cooldown_period = 300
        if app.state.config and 'preferences' in app.state.config:
            cooldown_period = app.state.config['preferences'].get('cooldown_period', 300)
        app.state.channel_manager = ChannelManager(cooldown_period=cooldown_period)

    return await call_next(request)
def get_timeout_value(provider_timeouts, original_model):
    """Look up the timeout configured for a model.

    Resolution order: an exact key match, then the first non-``"default"`` key
    that is a substring of ``original_model``, then the ``"default"`` entry.

    Returns:
        The configured timeout, or ``None`` when nothing applies.
    """
    # Exact match wins.
    if original_model in provider_timeouts:
        return provider_timeouts[original_model]

    # Fuzzy match: a configured name contained in the model name.
    for candidate in provider_timeouts:
        if candidate != "default" and candidate in original_model:
            return provider_timeouts[candidate]

    # Nothing matched: use the table's default, if it has one.
    return provider_timeouts.get("default")
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None):
    """Send ``request`` to one provider and build the response.

    The upstream engine is inferred from the provider's base URL, the resolved
    model name and the endpoint, unless the provider sets ``"engine"``
    explicitly. The outcome of the attempt is recorded via
    ``update_channel_stats``.

    Args:
        request: The validated request; its ``stream`` flag may be overwritten.
        provider: Provider configuration mapping.
        endpoint: The API path being served, when it affects engine selection.

    Returns:
        A streaming response, or an ``audio/mpeg`` response for speech requests.

    Raises:
        Whatever the upstream call raised, after the failure was recorded.
    """
    url = provider['base_url']
    parsed_url = urlparse(url)

    engine = None
    # NOTE(review): the "/v1" test subsumes "/v1beta", so every base URL whose
    # path starts with "/v1" is classified as gemini here - confirm intended.
    if parsed_url.path.startswith("/v1beta") or parsed_url.path.startswith("/v1"):
        engine = "gemini"
    elif parsed_url.netloc == 'aiplatform.googleapis.com':
        engine = "vertex"
    elif parsed_url.netloc == 'api.cloudflare.com':
        engine = "cloudflare"
    elif parsed_url.netloc == 'api.anthropic.com' or parsed_url.path.endswith("v1/messages"):
        engine = "claude"
    elif parsed_url.netloc == 'openrouter.ai':
        engine = "openrouter"
    elif parsed_url.netloc == 'api.cohere.com':
        engine = "cohere"
        request.stream = True
    else:
        engine = "gpt"

    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]

    # Models outside the known families go through the openrouter payload format.
    if "claude" not in original_model \
    and "gpt" not in original_model \
    and "gemini" not in original_model \
    and parsed_url.netloc != 'api.cloudflare.com' \
    and parsed_url.netloc != 'api.cohere.com':
        engine = "openrouter"

    if "claude" in original_model and engine == "vertex":
        engine = "vertex-claude"

    if "gemini" in original_model and engine == "vertex":
        engine = "vertex-gemini"

    if "o1-preview" in original_model or "o1-mini" in original_model:
        engine = "o1"
        request.stream = False

    # Endpoint-specific engines override the URL/model based choice.
    if endpoint == "/v1/images/generations":
        engine = "dalle"
        request.stream = False

    if endpoint == "/v1/audio/transcriptions":
        engine = "whisper"
        request.stream = False

    if endpoint == "/v1/moderations":
        engine = "moderation"
        request.stream = False

    if endpoint == "/v1/embeddings":
        engine = "embedding"

    if endpoint == "/v1/audio/speech":
        engine = "tts"
        request.stream = False

    # An explicit engine in the provider configuration always wins.
    if provider.get("engine"):
        engine = provider["engine"]

    channel_id = f"{provider['provider']}"
    logger.info(f"provider: {channel_id:<11} model: {request.model:<22} engine: {engine}")

    url, headers, payload = await get_payload(request, engine, provider)
    if is_debug:
        logger.info(json.dumps(headers, indent=4, ensure_ascii=False))
        # File uploads are not dumped to the log.
        if not payload.get("file"):
            logger.info(json.dumps(payload, indent=4, ensure_ascii=False))

    current_info = request_info.get()

    # Timeout: provider table first, then the global table, then the default.
    provider_timeouts = safe_get(app.state.provider_timeouts, channel_id, default=app.state.provider_timeouts["global_time_out"])
    timeout_value = get_timeout_value(provider_timeouts, original_model)
    if timeout_value is None:
        timeout_value = get_timeout_value(app.state.provider_timeouts["global_time_out"], original_model)
    if timeout_value is None:
        timeout_value = app.state.timeouts.get("default", DEFAULT_TIMEOUT)

    proxy = safe_get(provider, "preferences", "proxy", default=None)

    try:
        async with app.state.client_manager.get_client(timeout_value, url, proxy) as client:
            if request.stream:
                generator = fetch_response_stream(client, url, headers, payload, engine, original_model)
                wrapped_generator, first_response_time = await error_handling_wrapper(generator, channel_id)
                response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream")
            else:
                generator = fetch_response(client, url, headers, payload, engine, original_model)
                wrapped_generator, first_response_time = await error_handling_wrapper(generator, channel_id)

                if endpoint == "/v1/audio/speech":
                    if not isinstance(wrapped_generator, bytes):
                        # Previously this path left ``response`` unassigned and
                        # failed later with an UnboundLocalError.
                        raise HTTPException(status_code=500, detail="Unexpected non-binary response for /v1/audio/speech")
                    response = Response(content=wrapped_generator, media_type="audio/mpeg")
                else:
                    first_element = await anext(wrapped_generator)
                    # str.lstrip("data: ") strips any of those characters
                    # rather than the prefix; remove the prefix itself.
                    if first_element.startswith("data:"):
                        first_element = first_element.removeprefix("data:").lstrip(" ")
                    first_element = json.loads(first_element)
                    response = StarletteStreamingResponse(iter([json.dumps(first_element)]), media_type="application/json")

            # Record the successful attempt.
            await update_channel_stats(current_info["request_id"], channel_id, request.model, current_info["api_key"], success=True)
            current_info["first_response_time"] = first_response_time
            current_info["success"] = True
            current_info["provider"] = channel_id
            return response

    except (Exception, asyncio.CancelledError):
        # CancelledError is a BaseException, so it is listed explicitly.
        await update_channel_stats(current_info["request_id"], channel_id, request.model, current_info["api_key"], success=False)
        raise
def weighted_round_robin(weights):
    """Expand a weight table into a selection order.

    Each provider appears as many times as its weight, interleaved so that
    heavier providers are spread across the sequence instead of being grouped.

    Args:
        weights: Mapping of provider name to integer weight. Entries whose
            weight is not positive are never selected.

    Returns:
        list: Provider names, ``sum(weights)`` entries long.
    """
    # A zero weight used to raise ZeroDivisionError in the ratio below.
    weights = {name: weight for name, weight in weights.items() if weight > 0}
    total_weight = sum(weights.values())
    current_weights = dict.fromkeys(weights, 0)

    weighted_provider_list = []
    for _ in range(total_weight):
        best_ratio = -1
        selected = None
        for name, weight in weights.items():
            current_weights[name] += weight
            ratio = current_weights[name] / weight
            # Strict comparison: ties go to the provider listed first.
            if ratio > best_ratio:
                best_ratio = ratio
                selected = name

        weighted_provider_list.append(selected)
        current_weights[selected] -= total_weight

    return weighted_provider_list
import random |
def lottery_scheduling(weights):
    """Draw providers at random, each with probability proportional to its weight.

    One ticket is drawn per unit of total weight, so the result has as many
    entries as the weights sum to.

    Args:
        weights: Mapping of provider name to integer weight (ticket count).

    Returns:
        list: The provider selected by each draw, in draw order.
    """
    total_tickets = sum(weights.values())
    no_winner = object()

    def holder_of(ticket):
        # The winner is the first provider whose running ticket total reaches the ticket.
        running_total = 0
        for name, share in weights.items():
            running_total += share
            if ticket <= running_total:
                return name
        return no_winner

    selections = []
    for _ in range(total_tickets):
        winner = holder_of(random.randint(1, total_tickets))
        if winner is not no_winner:
            selections.append(winner)
    return selections
def get_provider_rules(model_rule, config, request_model):
    """Expand one model rule of an API key into ``"provider/model"`` strings.

    Supported rule forms:
        * ``"all"``: every model of every provider.
        * ``"<name/with/slash>"``: a model whose own name contains slashes,
          offered by any provider.
        * ``"provider/*"``: the requested model (or, for a request ending in
          ``*``, every model with that prefix) on that provider.
        * ``"provider/model"``: that model on that provider, when it matches
          the request exactly or by ``*`` prefix.
        * ``"model"``: that model on any provider offering it.

    Returns:
        list: The matching ``"provider/model"`` strings.
    """
    rules = []

    def add_providers_offering(model_name):
        # Every provider whose model table contains model_name.
        for provider in config['providers']:
            if model_name in get_model_dict(provider).keys():
                rules.append(provider['provider'] + "/" + model_name)

    if model_rule == "all":
        for provider in config["providers"]:
            rules.extend(provider["provider"] + "/" + model for model in get_model_dict(provider).keys())
        return rules

    if "/" not in model_rule:
        add_providers_offering(model_rule)
        return rules

    if model_rule.startswith("<") and model_rule.endswith(">"):
        # Angle brackets mark a model name that itself contains slashes.
        add_providers_offering(model_rule[1:-1])
        return rules

    provider_name, _, model_name_split = model_rule.partition("/")
    models_list = []
    for provider in config['providers']:
        model_dict = get_model_dict(provider)
        if provider['provider'] == provider_name:
            models_list.extend(list(model_dict.keys()))

    wants_prefix = request_model.endswith("*")
    prefix = request_model.rstrip("*")

    if model_name_split == "*":
        # "provider/*": every model of the provider is allowed.
        if request_model in models_list:
            rules.append(provider_name + "/" + request_model)
        if wants_prefix:
            rules.extend(provider_name + "/" + name for name in models_list if name.startswith(prefix))
    elif model_name_split == request_model or (wants_prefix and model_name_split.startswith(prefix)):
        # "provider/model": exact match, or prefix match for a "model*" request.
        if model_name_split in models_list:
            rules.append(provider_name + "/" + model_name_split)

    return rules
def get_provider_list(provider_rules, config, request_model):
    """Turn ``"provider/model"`` rules into provider configurations for a request.

    A rule contributes a provider when the provider offers the rule's model
    and that model either equals ``request_model`` or, for a request ending in
    ``*``, starts with the request's prefix.

    Returns:
        list: Deep copies of the matching provider configurations whose
        ``"model"`` entry is replaced by a single mapping from the provider's
        value for the rule's model to ``request_model``.
    """
    provider_list = []
    # Loop-invariant: the prefix a wildcard request matches against.
    wildcard_prefix = request_model.rstrip("*") if request_model.endswith("*") else None

    for item in provider_rules:
        if "/" not in item:
            continue
        rule_provider, _, model_name_split = item.partition("/")
        is_exact = model_name_split == request_model
        is_wildcard = wildcard_prefix is not None and model_name_split.startswith(wildcard_prefix)

        for provider in config['providers']:
            if provider['provider'] != rule_provider:
                continue
            model_dict = get_model_dict(provider)
            if model_name_split not in model_dict.keys():
                continue
            if not (is_exact or is_wildcard):
                continue
            # Copy only once the provider is known to match; the copy used to
            # be made for every candidate and thrown away for the non-matching ones.
            new_provider = copy.deepcopy(provider)
            new_provider["model"] = [{model_dict[model_name_split]: request_model}]
            provider_list.append(new_provider)
    return provider_list
def get_matching_providers(request_model, config, api_index):
    """Return the provider configurations the given API key may use for ``request_model``.

    Every model rule of the key at ``api_index`` is expanded and the combined
    rules are resolved into provider configurations.
    """
    key_model_rules = config['api_keys'][api_index]['model']
    provider_rules = [
        rule
        for model_rule in key_model_rules
        for rule in get_provider_rules(model_rule, config, request_model)
    ]
    return get_provider_list(provider_rules, config, request_model)
async def get_right_order_providers(request_model, config, api_index, scheduling_algorithm):
    """Return the providers to try for a request, in the order to try them.

    Providers matching the API key's rules are filtered by the channel
    manager's cooldown list, shuffled for the ``"random"`` algorithm, and
    re-ordered by the key's ``weights`` when at least two of them are weighted.

    Raises:
        HTTPException: 404 when no provider matches the model, 503 when every
            matching provider is currently excluded.
    """
    matching_providers = get_matching_providers(request_model, config, api_index)
    if not matching_providers:
        raise HTTPException(status_code=404, detail=f"No matching model found: {request_model}")

    # Cooldown filtering only applies when there is an alternative to fall back on.
    channel_manager = app.state.channel_manager
    if channel_manager.cooldown_period > 0 and len(matching_providers) > 1:
        matching_providers = await channel_manager.get_available_providers(matching_providers)
        if not matching_providers:
            raise HTTPException(status_code=503, detail="No available providers at the moment")

    if scheduling_algorithm == "random":
        # Sample by the current length: the count taken before the cooldown
        # filter can exceed the population, which makes random.sample raise
        # ValueError.
        matching_providers = random.sample(matching_providers, len(matching_providers))

    weights = safe_get(config, 'api_keys', api_index, "weights")

    if weights:
        intersection = None
        all_providers = {provider['provider'] + "/" + request_model for provider in matching_providers}
        if all_providers:
            # Resolve the weight keys with the same rule expansion as the model rules.
            provider_rules = []
            for model_rule in weights:
                provider_rules.extend(get_provider_rules(model_rule, config, request_model))
            provider_list = get_provider_list(provider_rules, config, request_model)
            weight_keys = {provider['provider'] + "/" + request_model for provider in provider_list}

            intersection = all_providers.intersection(weight_keys)
            # Weighting a single provider is meaningless.
            if len(intersection) == 1:
                intersection = None

        if intersection:
            filtered_weights = {k.split("/")[0]: v for k, v in weights.items() if k.split("/")[0] + "/" + request_model in intersection}

            if scheduling_algorithm == "weighted_round_robin":
                weighted_provider_name_list = weighted_round_robin(filtered_weights)
            elif scheduling_algorithm == "lottery":
                weighted_provider_name_list = lottery_scheduling(filtered_weights)
            else:
                weighted_provider_name_list = list(filtered_weights.keys())

            # Rebuild the provider list in the scheduled order.
            matching_providers = [
                provider
                for provider_name in weighted_provider_name_list
                for provider in matching_providers
                if provider['provider'] == provider_name
            ]

    if is_debug:
        for provider in matching_providers:
            logger.info("available provider: %s", json.dumps(provider, indent=4, ensure_ascii=False, default=circular_list_encoder))

    return matching_providers
import asyncio |
class ModelRequestHandler:
    """Send a request to the providers matching its model, failing over between them.

    The instance keeps, per model name, the last start index it handed out and
    an asyncio lock that guards the update of that index.
    """

    def __init__(self):
        # Last start index used per model name; -1 so the first increment yields 0.
        self.last_provider_indices = defaultdict(lambda: -1)
        # One lock per model name, created on first access.
        self.locks = defaultdict(asyncio.Lock)

    async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], api_index: int = None, endpoint=None):
        """Try the matching providers in turn until one of them answers.

        Args:
            request: The parsed request; its ``model`` attribute selects providers.
            api_index: Index of the caller's entry in ``config['api_keys']``.
            endpoint: Passed through unchanged to ``process_request``.

        Returns:
            Whatever ``process_request`` returns on the first successful attempt.
            Otherwise a ``JSONResponse`` carrying the status code and message of
            the last failure: immediately when ``AUTO_RETRY`` is falsy, or after
            the attempt budget is exhausted.

        Raises:
            HTTPException: 404 when the caller's api-key entry has no ``model`` value.
        """
        config = app.state.config
        request_model = request.model
        if not safe_get(config, 'api_keys', api_index, 'model'):
            raise HTTPException(status_code=404, detail=f"No matching model found: {request_model}")
        scheduling_algorithm = safe_get(config, 'api_keys', api_index, "preferences", "SCHEDULING_ALGORITHM", default="fixed_priority")
        matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
        num_matching_providers = len(matching_providers)
        # Defaults reported if the loop below never records a failure.
        status_code = 500
        error_message = None
        start_index = 0
        if scheduling_algorithm != "fixed_priority":
            # Advance the per-model start index under the lock so each call
            # begins one provider further along than the previous one.
            async with self.locks[request_model]:
                self.last_provider_indices[request_model] = (self.last_provider_indices[request_model] + 1) % num_matching_providers
                start_index = self.last_provider_indices[request_model]
        auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
        index = 0
        # Extra attempts on top of one per provider: with a single provider that
        # holds several API keys, one per key; otherwise int(auto_retry).
        if num_matching_providers == 1 and (count := provider_api_circular_list[matching_providers[0]['provider']].get_items_count()) > 1:
            retry_count = count
        else:
            retry_count = int(auto_retry)
        while True:
            if index >= num_matching_providers + retry_count:
                break
            # Walk the provider list circularly, starting from start_index.
            current_index = (start_index + index) % num_matching_providers
            index += 1
            provider = matching_providers[current_index]
            try:
                response = await process_request(request, provider, endpoint)
                return response
            # NOTE(review): asyncio.CancelledError derives from BaseException on
            # Python 3.8+, which is presumably why it is listed next to Exception;
            # handling it here means a cancelled attempt can be retried — confirm
            # that this is intended.
            except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout, httpx.ConnectError) as e:
                # Map the exception type to the status code and message that are
                # reported if this turns out to be the last attempt.
                if isinstance(e, httpx.ReadTimeout):
                    status_code = 504
                    error_message = "Request timed out"
                elif isinstance(e, httpx.ConnectError):
                    status_code = 503
                    error_message = "Unable to connect to service"
                elif isinstance(e, httpx.ReadError):
                    status_code = 502
                    error_message = "Network read error"
                elif isinstance(e, httpx.RemoteProtocolError):
                    status_code = 502
                    error_message = "Remote protocol error"
                elif isinstance(e, asyncio.CancelledError):
                    status_code = 499
                    error_message = "Request was cancelled"
                elif isinstance(e, HTTPException):
                    status_code = e.status_code
                    error_message = str(e.detail)
                else:
                    status_code = 500
                    error_message = str(e) or f"Unknown error: {e.__class__.__name__}"
                channel_id = f"{provider['provider']}"
                if app.state.channel_manager.cooldown_period > 0 and num_matching_providers > 1:
                    # Exclude this channel for the model, then recompute the
                    # provider list; if its length changed, restart the attempt
                    # counter against the new list.
                    await app.state.channel_manager.exclude_model(channel_id, request_model)
                    matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
                    last_num_matching_providers = num_matching_providers
                    num_matching_providers = len(matching_providers)
                    if num_matching_providers != last_num_matching_providers:
                        index = 0
                cooling_time = safe_get(provider, "preferences", "api_key_cooldown_period", default=0)
                api_key_count = provider_api_circular_list[channel_id].get_items_count()
                # presumably the API key used by the failed attempt — verify
                # against the circular list implementation
                current_api = await provider_api_circular_list[channel_id].after_next_current()
                if cooling_time > 0 and api_key_count > 1:
                    await provider_api_circular_list[channel_id].set_cooling(current_api, cooling_time=cooling_time)
                logger.error(f"Error {status_code} with provider {channel_id} API key: {current_api}: {error_message}")
                if is_debug:
                    import traceback
                    traceback.print_exc()
                if auto_retry:
                    continue
                else:
                    return JSONResponse(
                        status_code=status_code,
                        content={"error": f"Error: Current provider response failed: {error_message}"}
                    )
        # Every attempt failed: record the failure on the current request info
        # and report the last error seen.
        current_info = request_info.get()
        current_info["first_response_time"] = -1
        current_info["success"] = False
        current_info["provider"] = None
        return JSONResponse(
            status_code=status_code,
            content={"error": f"All {request.model} error: {error_message}"}
        )
# Single handler instance shared by the /v1 model endpoints defined below.
model_handler = ModelRequestHandler()
# Bearer-token scheme used as the credentials dependency of the API routes.
security = HTTPBearer()
async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Reject the request with 429 when the caller exceeded its rate limit.

    The bearer token is matched against ``app.state.api_list``, first exactly
    and then by prefix. Known tokens are limited per ``client_ip:token``;
    unknown or missing tokens are limited per client IP only.

    Raises:
        HTTPException: 429 when ``rate_limiter`` reports the key as limited.
    """
    token = credentials.credentials if credentials else None
    api_list = app.state.api_list
    api_index = None
    # Only attempt the lookup with a real token: a missing one would otherwise
    # reach ``token.startswith`` and fail with AttributeError.
    if token:
        try:
            api_index = api_list.index(token)
        except ValueError:
            api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)
    if api_index is None:
        # Deliberately not logging the presented credential itself.
        logger.warning("Invalid or missing API Key")
        token = None
    # ``request.client`` can be None (no peer address available).
    client_ip = request.client.host if request.client else "unknown"
    rate_limit_key = f"{client_ip}:{token}" if token else client_ip
    limits = await get_user_rate_limit(app, api_index)
    if await rate_limiter.is_rate_limited(rate_limit_key, limits):
        raise HTTPException(status_code=429, detail="Too many requests")
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Resolve the bearer token to its position in ``app.state.api_list``.

    An exact match wins; otherwise the first configured key that the token
    starts with is used.

    Returns:
        The index of the matching key.

    Raises:
        HTTPException: 403 when no configured key matches.
    """
    presented = credentials.credentials
    known_keys = app.state.api_list
    if presented in known_keys:
        return known_keys.index(presented)
    for position, known in enumerate(known_keys):
        if presented.startswith(known):
            return position
    raise HTTPException(status_code=403, detail="Invalid or missing API Key")
def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Accept the request only when the bearer token belongs to an admin key.

    Returns:
        The presented token.

    Raises:
        HTTPException: 403 when the token matches no configured key, or when
            the matching entry's ``role`` is not ``"admin"``.
    """
    token = credentials.credentials
    known_keys = app.state.api_list
    if token in known_keys:
        api_index = known_keys.index(token)
    else:
        # Fall back to the first configured key that is a prefix of the token.
        api_index = next(
            (position for position, known in enumerate(known_keys) if token.startswith(known)),
            None,
        )
    if api_index is None:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")
    role = app.state.api_keys_db[api_index].get('role')
    if role != "admin":
        raise HTTPException(status_code=403, detail="Permission denied")
    return token
# Router for the API endpoints; every route registered on it is prefixed with "/api".
v1_router = APIRouter(prefix="/api")
@v1_router.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
async def request_model(request: RequestModel, api_index: int = Depends(verify_api_key)):
    """Serve a chat-completions call through the shared model handler."""
    outcome = await model_handler.request_model(request, api_index)
    return outcome
@v1_router.options("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
async def options_handler():
    """Answer OPTIONS on the chat-completions path with a fixed 200 JSON body."""
    body = {"detail": "OPTIONS allowed"}
    return JSONResponse(status_code=200, content=body)
@v1_router.get("/v1/models", dependencies=[Depends(rate_limit_dependency)])
async def list_models(api_index: int = Depends(verify_api_key)):
    """Return the models reported by ``post_all_models`` for the caller's key."""
    available = post_all_models(api_index, app.state.config)
    payload = {"object": "list", "data": available}
    return JSONResponse(content=payload)
@v1_router.post("/v1/images/generations", dependencies=[Depends(rate_limit_dependency)])
async def images_generations(
    request: ImageGenerationRequest,
    api_index: int = Depends(verify_api_key)
):
    """Serve an image-generation call through the shared model handler."""
    target = "/v1/images/generations"
    outcome = await model_handler.request_model(request, api_index, endpoint=target)
    return outcome
@v1_router.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
async def embeddings(
    request: EmbeddingRequest,
    api_index: int = Depends(verify_api_key)
):
    """Serve an embeddings call through the shared model handler."""
    target = "/v1/embeddings"
    outcome = await model_handler.request_model(request, api_index, endpoint=target)
    return outcome
@v1_router.post("/v1/audio/speech", dependencies=[Depends(rate_limit_dependency)])
async def audio_speech(
    request: TextToSpeechRequest,
    # verify_api_key returns the integer index of the matching key.
    api_index: int = Depends(verify_api_key)
):
    """Serve a text-to-speech call through the shared model handler."""
    return await model_handler.request_model(request, api_index, endpoint="/v1/audio/speech")
@v1_router.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
async def moderations(
    request: ModerationRequest,
    api_index: int = Depends(verify_api_key)
):
    """Serve a moderation call through the shared model handler."""
    target = "/v1/moderations"
    outcome = await model_handler.request_model(request, api_index, endpoint=target)
    return outcome
from fastapi import UploadFile, File, Form, HTTPException |
import io |
@v1_router.post("/v1/audio/transcriptions", dependencies=[Depends(rate_limit_dependency)])
async def audio_transcriptions(
    file: UploadFile = File(...),
    model: str = Form(...),
    api_index: int = Depends(verify_api_key)
):
    """Serve an audio-transcription call for an uploaded file.

    The upload is read fully into memory, wrapped in a ``BytesIO`` and handed
    to the shared model handler as an ``AudioTranscriptionRequest``.

    Raises:
        HTTPException: whatever the handler raised (status preserved); 400 on
            ``UnicodeDecodeError``; 500 on any other error.
    """
    try:
        content = await file.read()
        file_obj = io.BytesIO(content)
        request = AudioTranscriptionRequest(
            file=(file.filename, file_obj, file.content_type),
            model=model
        )
        return await model_handler.request_model(request, api_index, endpoint="/v1/audio/transcriptions")
    except HTTPException:
        # Keep the handler's own status (e.g. 404 for an unknown model)
        # instead of letting the generic branch below turn it into a 500.
        raise
    except UnicodeDecodeError as e:
        raise HTTPException(status_code=400, detail="Invalid audio file encoding") from e
    except Exception as e:
        if is_debug:
            import traceback
            traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}") from e
@v1_router.get("/v1/generate-api-key", dependencies=[Depends(rate_limit_dependency)])
def generate_api_key():
    """Return a fresh key: ``sk-`` followed by 36 random letters and digits."""
    alphabet = string.ascii_letters + string.digits
    suffix = [secrets.choice(alphabet) for _ in range(36)]
    return JSONResponse(content={"api_key": "sk-" + ''.join(suffix)})
from datetime import datetime, timedelta, timezone |
from sqlalchemy import func, desc, case |
from fastapi import Query |
@v1_router.get("/v1/stats", dependencies=[Depends(rate_limit_dependency)])
async def get_stats(
    request: Request,
    token: str = Depends(verify_admin_api_key),
    hours: int = Query(default=24, ge=1, le=720, description="Number of hours to look back for stats (1-720)")
):
    '''
    ## Get statistics
    Use `/v1/stats` to get usage statistics for every channel over the last 24 hours. Send your own uni-api admin API key with the request.
    The data includes:
    1. The success rate of each model under each channel, sorted from highest to lowest.
    2. The overall success rate of each channel, sorted from highest to lowest.
    3. The total number of requests for each model across all channels.
    4. The number of requests for each endpoint.
    5. The number of requests from each IP.
    `/v1/stats?hours=48` The `hours` parameter controls how many recent hours the statistics cover; when `hours` is omitted, the last 24 hours are used.
    Other statistics can be queried with your own SQL against the database. They include: time to first token, total processing time of each request, whether each request succeeded, whether each request passed moderation, the text content of each request, the API key of each request, and the input and output token counts of each request.
    '''
    # Without a database there is nothing to aggregate; answer with an empty
    # payload rather than touching the session factory.
    if DISABLE_DATABASE:
        return JSONResponse(content={"stats": {}})

    def success_rate(row):
        """Fraction of successful requests in an aggregated row (0 when empty)."""
        return row.success_count / row.total if row.total > 0 else 0

    async with async_session() as session:
        start_time = datetime.now(timezone.utc) - timedelta(hours=hours)
        channel_model_stats = await session.execute(
            select(
                ChannelStat.provider,
                ChannelStat.model,
                func.count().label('total'),
                func.sum(case((ChannelStat.success == True, 1), else_=0)).label('success_count')
            )
            .where(ChannelStat.timestamp >= start_time)
            .group_by(ChannelStat.provider, ChannelStat.model)
        )
        channel_model_stats = channel_model_stats.fetchall()
        channel_stats = await session.execute(
            select(
                ChannelStat.provider,
                func.count().label('total'),
                func.sum(case((ChannelStat.success == True, 1), else_=0)).label('success_count')
            )
            .where(ChannelStat.timestamp >= start_time)
            .group_by(ChannelStat.provider)
        )
        channel_stats = channel_stats.fetchall()
        model_stats = await session.execute(
            select(RequestStat.model, func.count().label('count'))
            .where(RequestStat.timestamp >= start_time)
            .group_by(RequestStat.model)
            .order_by(desc('count'))
        )
        model_stats = model_stats.fetchall()
        endpoint_stats = await session.execute(
            select(RequestStat.endpoint, func.count().label('count'))
            .where(RequestStat.timestamp >= start_time)
            .group_by(RequestStat.endpoint)
            .order_by(desc('count'))
        )
        endpoint_stats = endpoint_stats.fetchall()
        ip_stats = await session.execute(
            select(RequestStat.client_ip, func.count().label('count'))
            .where(RequestStat.timestamp >= start_time)
            .group_by(RequestStat.client_ip)
            .order_by(desc('count'))
        )
        ip_stats = ip_stats.fetchall()
    # The two-column count rows are unpacked positionally: a column labelled
    # "count" collides with the sequence-style ``count`` method of result rows,
    # so attribute access would not reliably return the number.
    stats = {
        "time_range": f"Last {hours} hours",
        "channel_model_success_rates": [
            {
                "provider": stat.provider,
                "model": stat.model,
                "success_rate": success_rate(stat),
                "total_requests": stat.total
            } for stat in sorted(channel_model_stats, key=success_rate, reverse=True)
        ],
        "channel_success_rates": [
            {
                "provider": stat.provider,
                "success_rate": success_rate(stat),
                "total_requests": stat.total
            } for stat in sorted(channel_stats, key=success_rate, reverse=True)
        ],
        "model_request_counts": [
            {
                "model": model,
                "count": count
            } for model, count in model_stats
        ],
        "endpoint_request_counts": [
            {
                "endpoint": endpoint,
                "count": count
            } for endpoint, count in endpoint_stats
        ],
        "ip_request_counts": [
            {
                "ip": client_ip,
                "count": count
            } for client_ip, count in ip_stats
        ]
    }
    return JSONResponse(content=stats)
from fastapi import FastAPI, Request |
from fastapi import Form as FastapiForm, HTTPException, Depends |
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse |
from fastapi.security import APIKeyHeader |
from typing import Optional, List |
from xue import HTML, Head, Body, Div, xue_initialize, Script, Ul, Li |
from xue.components import input, dropdown, sheet, form, button, checkbox, sidebar, chart |
from xue.components.model_config_row import model_config_row |
from components.provider_table import data_table |
from ruamel.yaml import YAML |
# ruamel YAML instance configured to keep quotes and use a fixed indentation.
yaml = YAML()
yaml.preserve_quotes = True
yaml.indent(mapping=2, sequence=4, offset=2)
# Router for the admin frontend pages; registered without a path prefix.
frontend_router = APIRouter()
# Optional API-key header; auto_error=False lets get_api_key fall back to the
# cookie or query parameter instead of failing when the header is absent.
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
async def get_api_key(request: Request, x_api_key: Optional[str] = Depends(api_key_header)):
    """Return the caller's admin key, or ``None`` when it is absent or wrong.

    The key is taken from the API-key header, then the ``x_api_key`` cookie,
    then the ``x_api_key`` query parameter. If ``app.state`` has no ``config``
    yet, ``ensure_config`` is awaited before the comparison.
    """
    candidate = x_api_key
    if not candidate:
        candidate = request.cookies.get("x_api_key") or request.query_params.get("x_api_key")
    if not hasattr(app.state, 'config'):
        await ensure_config(request, lambda: None)
    return candidate if candidate == app.state.admin_api_key else None
async def frontend_rate_limit_dependency(request: Request, x_api_key: str = Depends(get_api_key)):
    """Apply the fixed frontend rate limit, raising 429 once it is exceeded.

    Authenticated callers are keyed by ``client_ip:key``; everyone else by
    client IP alone.
    """
    client_ip = request.client.host
    if x_api_key:
        rate_limit_key = f"{client_ip}:{x_api_key}"
    else:
        rate_limit_key = client_ip
    # Limit tuples in the form expected by rate_limiter.is_rate_limited.
    frontend_limits = [(100, 60)]
    limited = await rate_limiter.is_rate_limited(rate_limit_key, frontend_limits)
    if limited:
        raise HTTPException(status_code=429, detail="Too many requests")
# Initialise the xue rendering library with its tailwind option enabled.
xue_initialize(tailwind=True)
# Column definitions shared by every rendering of the provider table; each
# "value" is the provider-dict key shown in that column.
data_table_columns = [
    {"label": "Provider", "value": "provider", "sortable": True},
    {"label": "Base url", "value": "base_url", "sortable": True},
    {"label": "Tools", "value": "tools", "sortable": True},
]
@frontend_router.get("/login", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def login_page():
    """Render the login page: one password field posting to /verify-api-key."""
    key_field = form.FormField("API Key", "x_api_key", type="password", placeholder="输入API密钥", required=True)
    # Target element that receives the response of the form post.
    error_slot = Div(id="error-message", class_="text-red-500 mt-2")
    submit_row = Div(
        button.button("提交", variant="primary", type="submit"),
        class_="flex justify-end mt-4"
    )
    login_form = form.Form(
        key_field,
        error_slot,
        submit_row,
        hx_post="/verify-api-key",
        hx_target="#error-message",
        hx_swap="innerHTML",
        class_="space-y-4"
    )
    page = HTML(
        Head(title="登录"),
        Body(Div(login_form, class_="container mx-auto p-4 max-w-md"))
    )
    return page.render()
@frontend_router.post("/verify-api-key", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def verify_api_key(x_api_key: str = FastapiForm(...)):
    """Check the submitted key against the admin key and start a session.

    On a match the response sets the ``x_api_key`` cookie (30 minutes) and an
    ``HX-Redirect`` header pointing at ``/``; otherwise an error snippet is
    returned.

    Note: this handler rebinds the module-level name ``verify_api_key`` that
    the bearer-token dependency above also uses.
    """
    if x_api_key != app.state.admin_api_key:
        return Div("无效的API密钥", class_="text-red-500").render()
    response = JSONResponse(content={"success": True})
    response.headers["HX-Redirect"] = "/"
    cookie_options = {
        "key": "x_api_key",
        "value": x_api_key,
        "httponly": True,
        "max_age": 1800,
        "secure": False,
        "samesite": "lax",
    }
    response.set_cookie(**cookie_options)
    return response
# Entries of the navigation sidebar. Each "hx" mapping names the GET path to
# load and the element ("#main-content") that receives the response.
sidebar_items = [
    {
        "icon": "layout-dashboard",
        "label": "Dashboard",
        "value": "dashboard",
        "hx": {"get": "/dashboard", "target": "#main-content"}
    },
    {
        "icon": "database",
        "label": "Data",
        "value": "data",
        "hx": {"get": "/data", "target": "#main-content"}
    },
]
@frontend_router.get("/", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def root(x_api_key: str = Depends(get_api_key)):
    """Render the main admin page: sidebar plus the provider table.

    Redirects to ``/login`` (303) when no valid admin key accompanies the
    request. The embedded script reloads the table through ``/filter-table``
    whenever the filter input changes.
    """
    if not x_api_key:
        return RedirectResponse(url="/login", status_code=303)
    result = HTML(
        Head(
            # The filter text is URL-encoded so characters such as "&", "#"
            # or "+" reach the server as part of the filter value.
            Script("""
                document.addEventListener('DOMContentLoaded', function() {
                    const filterInput = document.getElementById('users-table-filter');
                    filterInput.addEventListener('input', function() {
                        const filterValue = this.value;
                        htmx.ajax('GET', `/filter-table?filter=${encodeURIComponent(filterValue)}`, '#users-table');
                    });
                });
            """),
            title="uni-api"
        ),
        Body(
            Div(
                sidebar.Sidebar("zap", "uni-api", sidebar_items, is_collapsed=False, active_item="dashboard"),
                Div(
                    Div(
                        data_table(data_table_columns, app.state.config["providers"], "users-table"),
                        class_="p-4"
                    ),
                    Div(id="sheet-container"),
                    id="main-content",
                    class_="ml-[200px] p-6 transition-[margin] duration-200 ease-in-out"
                ),
                class_="flex"
            ),
            class_="container mx-auto",
            id="body"
        )
    ).render()
    return result
@frontend_router.get("/sidebar/toggle", response_class=HTMLResponse)
async def toggle_sidebar(is_collapsed: bool = False):
    """Render the sidebar in the opposite collapsed state to the one given."""
    flipped = not is_collapsed
    bar = sidebar.Sidebar("zap", "uni-api", sidebar_items, is_collapsed=flipped, active_item="dashboard")
    return bar.render()
@app.get("/sidebar/update/{active_item}", response_class=HTMLResponse)
async def update_sidebar(active_item: str):
    """Render the expanded sidebar with ``active_item`` marked as active."""
    bar = sidebar.Sidebar("zap", "uni-api", sidebar_items, is_collapsed=False, active_item=active_item)
    return bar.render()
@frontend_router.get("/dashboard", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def data_page(x_api_key: str = Depends(get_api_key)):
    """Render the dashboard fragment (provider table) for the main content area.

    Redirects to ``/login`` (303) without a valid admin key.
    """
    if not x_api_key:
        return RedirectResponse(url="/login", status_code=303)
    provider_table = data_table(data_table_columns, app.state.config["providers"], "users-table")
    table_wrapper = Div(provider_table, class_="p-4")
    sheet_slot = Div(id="sheet-container")
    fragment = Div(
        table_wrapper,
        sheet_slot,
        id="main-content",
        class_="ml-[200px] p-6 transition-[margin] duration-200 ease-in-out"
    )
    return fragment.render()
@frontend_router.get("/data", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def data_page(x_api_key: str = Depends(get_api_key)):
    """Render a stacked chart of per-model request counts for the last 24 hours.

    Redirects to ``/login`` (303) without a valid admin key and returns a
    short notice when the database is disabled.
    """
    if not x_api_key:
        return RedirectResponse(url="/login", status_code=303)
    # Only short-circuit when the database is actually disabled; otherwise
    # the chart below must be built.
    if DISABLE_DATABASE:
        return HTMLResponse("数据库已禁用")
    async with async_session() as session:
        start_time = datetime.now(timezone.utc) - timedelta(hours=24)
        model_stats = await session.execute(
            select(
                func.strftime('%H', RequestStat.timestamp).label('hour'),
                RequestStat.model,
                func.count().label('count')
            )
            .where(RequestStat.timestamp >= start_time)
            .group_by('hour', RequestStat.model)
            .order_by('hour')
        )
        model_stats = model_stats.fetchall()
    # Rows are (hour, model, count). They are unpacked positionally because a
    # column labelled "count" collides with the sequence-style ``count``
    # method of result rows.
    counts = {(hour, model): count for hour, model, count in model_stats}
    models = list({model for _, model, _ in model_stats})
    chart_data = []
    # The query window above is computed in UTC, so the hour buckets are
    # labelled from the UTC clock as well.
    # NOTE(review): assumes stored timestamps are UTC — confirm.
    current_hour = datetime.now(timezone.utc).hour
    for i in range(24):
        hour_str = f"{(current_hour - i) % 24:02d}"
        data_point = {"label": hour_str}
        for model in models:
            data_point[model] = counts.get((hour_str, model), 0)
        chart_data.append(data_point)
    # Built newest-first; the chart expects oldest-first.
    chart_data.reverse()
    chart_config = {
        model: {
            "label": model,
            # Spread the model colours evenly around the hue circle.
            "color": f"hsl({i * 360 / len(models)}, 70%, 50%)"
        }
        for i, model in enumerate(models)
    }
    result = HTML(
        Head(title="数据统计"),
        Body(
            Div(
                Div(
                    "模型使用统计 (24小时) - 按小时统计",
                    class_="text-2xl font-bold mb-4"
                ),
                Div(
                    chart.chart(
                        chart_data,
                        chart_config,
                        stacked=True,
                    ),
                    class_="mb-8 h-[400px]"
                ),
                id="main-content",
                class_="container ml-[200px] mx-auto p-4"
            )
        )
    ).render()
    return result
@frontend_router.get("/dropdown-menu/{menu_id}/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def get_columns_menu(menu_id: str, row_id: str):
    """Render the per-row action menu (Edit / Duplicate / Delete) for ``row_id``."""
    edit_action = {
        "label": "Edit",
        "value": "edit",
        "hx-get": f"/edit-sheet/{row_id}",
        "hx-target": "#sheet-container",
        "hx-swap": "innerHTML"
    }
    duplicate_action = {
        "label": "Duplicate",
        "value": "duplicate",
        "hx-post": f"/duplicate/{row_id}",
        "hx-target": "body",
        "hx-swap": "outerHTML"
    }
    delete_action = {
        "label": "Delete",
        "value": "delete",
        "hx-delete": f"/delete/{row_id}",
        "hx-target": "body",
        "hx-swap": "outerHTML",
        "hx-confirm": "Are you sure you want to delete this configuration?"
    }
    actions = [edit_action, duplicate_action, delete_action]
    markup = dropdown.dropdown_menu_content(menu_id, actions).render()
    print(markup)
    return markup
@frontend_router.get("/dropdown-menu/{menu_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def get_columns_menu(menu_id: str):
    """Render the dropdown that lists the provider-table columns."""
    menu = dropdown.dropdown_menu_content(menu_id, data_table_columns)
    markup = menu.render()
    print(markup)
    return markup
@frontend_router.get("/filter-table", response_class=HTMLResponse)
async def filter_table(filter: str = ""):
    """Render the provider table limited to rows matching ``filter``.

    A provider matches when the lower-cased filter text occurs in the string
    form of its ``provider``, ``base_url`` or ``tools`` value. Row ids keep
    the providers' positions in the unfiltered list.
    """
    needle = filter.lower()
    kept_rows = []
    kept_ids = []
    for position, provider in enumerate(app.state.config["providers"]):
        # Fields are checked lazily in this order, stopping at the first hit.
        if any(needle in str(provider[field]).lower() for field in ("provider", "base_url", "tools")):
            kept_ids.append(position)
            kept_rows.append(provider)
    table = data_table(data_table_columns, kept_rows, "users-table", with_filter=False, row_ids=kept_ids)
    return table.render()
@frontend_router.post("/add-model", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def add_model():
    """Render an empty model row whose id is derived from the current time."""
    timestamp_text = str(time())
    row_id = f"model{hash(timestamp_text)}"
    return model_config_row(row_id).render()
def render_api_keys(row_id, api_keys):
    """Build the ``#api-keys-container`` list: one text input and Delete button per key.

    Args:
        row_id: Provider row the keys belong to; used in each delete URL.
        api_keys: Key strings, rendered in order as ``api_key_<position>`` fields.

    Returns:
        The unrendered ``Ul`` element.
    """
    entries = []
    for position, key_value in enumerate(api_keys):
        key_input = input.input(
            type="text",
            placeholder="Enter API key",
            value=key_value,
            name=f"api_key_{position}",
            class_="flex-grow w-full"
        )
        remove_button = button.button(
            "Delete",
            variant="outline",
            type="button",
            class_="ml-2",
            hx_delete=f"/delete-api-key/{row_id}/{position}",
            hx_target="#api-keys-container",
            hx_swap="outerHTML"
        )
        line = Div(
            Div(key_input, class_="flex-grow"),
            remove_button,
            class_="flex items-center mb-2 w-full"
        )
        entries.append(Li(line))
    return Ul(*entries, id="api-keys-container", class_="space-y-2 w-full")
@frontend_router.get("/edit-sheet/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def get_edit_sheet(row_id: str, x_api_key: str = Depends(get_api_key)):
    """Render the edit sheet for the provider at position ``row_id``.

    The sheet shows the provider's name, base URL, API keys, models, tools
    flag and notes, and posts to ``/submit/{row_id}``. Because it contains the
    provider's API keys, callers without a valid admin key are redirected to
    ``/login`` (303).
    """
    if not x_api_key:
        return RedirectResponse(url="/login", status_code=303)
    row_data = get_row_data(row_id)
    model_list = []
    for index, model in enumerate(row_data["model"]):
        if isinstance(model, str):
            model_list.append(model_config_row(f"model{index}", model, "", True))
        elif isinstance(model, dict):
            # A dict entry contributes its first key/value pair to the row.
            key, value = list(model.items())[0]
            model_list.append(model_config_row(f"model{index}", key, value, True))
    # "api" may be a single key or a list of keys.
    api_keys = row_data["api"] if isinstance(row_data["api"], list) else [row_data["api"]]
    api_key_inputs = render_api_keys(row_id, api_keys)
    sheet_id = "edit-sheet"
    edit_sheet_content = sheet.SheetContent(
        sheet.SheetHeader(
            sheet.SheetTitle("Edit Item"),
            sheet.SheetDescription("Make changes to your item here.")
        ),
        sheet.SheetBody(
            Div(
                form.Form(
                    form.FormField("Provider", "provider", value=row_data["provider"], placeholder="Enter provider name", required=True),
                    form.FormField("Base URL", "base_url", value=row_data["base_url"], placeholder="Enter base URL", required=True),
                    Div(
                        Div("API Keys", class_="text-lg font-semibold mb-2"),
                        api_key_inputs,
                        button.button(
                            "Add API Key",
                            class_="mt-2",
                            hx_post=f"/add-api-key/{row_id}",
                            hx_target="#api-keys-container",
                            hx_swap="outerHTML"
                        ),
                        class_="mb-4"
                    ),
                    Div(
                        Div("Models", class_="text-lg font-semibold mb-2"),
                        Div(
                            *model_list,
                            id="models-container",
                            class_="space-y-2 max-h-[40vh] overflow-y-auto"
                        ),
                        button.button(
                            "Add Model",
                            class_="mt-2",
                            hx_post="/add-model",
                            hx_target="#models-container",
                            hx_swap="beforeend"
                        ),
                        class_="mb-4"
                    ),
                    Div(
                        checkbox.checkbox("tools", "Enable Tools", checked=row_data["tools"], name="tools"),
                        class_="mb-4"
                    ),
                    form.FormField("Notes", "notes", value=row_data.get("notes", ""), placeholder="Enter any additional notes"),
                    Div(
                        button.button("Submit", variant="primary", type="submit"),
                        button.button("Cancel", variant="outline", type="button", class_="ml-2", onclick=f"toggleSheet('{sheet_id}')"),
                        class_="flex justify-end mt-4"
                    ),
                    hx_post=f"/submit/{row_id}",
                    hx_swap="outerHTML",
                    hx_target="body",
                    class_="space-y-4"
                ),
                class_="container mx-auto p-4 max-w-2xl"
            )
        ),
        class_="max-h-[90vh] overflow-y-auto"
    )
    result = sheet.Sheet(
        sheet_id,
        Div(),
        edit_sheet_content,
        width="80%",
        max_width="800px"
    ).render()
    return result
@frontend_router.post("/add-api-key/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def add_api_key(row_id: str, x_api_key: str = Depends(get_api_key)):
    """Append an empty API-key slot for the provider and re-render the key list.

    When the provider's ``api`` value is a list, the empty entry is appended
    to that list in place. The response contains the provider's API keys, so
    callers without a valid admin key are redirected to ``/login`` (303).
    """
    if not x_api_key:
        return RedirectResponse(url="/login", status_code=303)
    row_data = get_row_data(row_id)
    api_keys = row_data["api"] if isinstance(row_data["api"], list) else [row_data["api"]]
    api_keys.append("")
    api_key_inputs = render_api_keys(row_id, api_keys)
    return api_key_inputs.render()
@frontend_router.delete("/delete-api-key/{row_id}/{index}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def delete_api_key(row_id: str, index: int, x_api_key: str = Depends(get_api_key)):
    """Remove the API key at ``index`` from the provider and re-render the key list.

    The last remaining key is never removed, and an index outside the list is
    ignored. The response contains the provider's API keys, so callers without
    a valid admin key are redirected to ``/login`` (303).
    """
    if not x_api_key:
        return RedirectResponse(url="/login", status_code=303)
    row_data = get_row_data(row_id)
    api_keys = row_data["api"] if isinstance(row_data["api"], list) else [row_data["api"]]
    # Reject negative and too-large positions: a negative index would silently
    # delete from the end, a large one would raise IndexError.
    if len(api_keys) > 1 and 0 <= index < len(api_keys):
        del api_keys[index]
    api_key_inputs = render_api_keys(row_id, api_keys)
    return api_key_inputs.render()
@frontend_router.get("/add-provider-sheet", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def get_add_provider_sheet():
    """Render the empty sheet used to create a provider; it posts to /submit/new."""
    sheet_id = "add-provider-sheet"
    header = sheet.SheetHeader(
        sheet.SheetTitle("Add New Provider"),
        sheet.SheetDescription("Enter details for the new provider.")
    )
    # Models section: an initially empty container that "Add Model" appends to.
    models_section = Div(
        Div("Models", class_="text-lg font-semibold mb-2"),
        Div(id="models-container"),
        button.button(
            "Add Model",
            class_="mt-2",
            hx_post="/add-model",
            hx_target="#models-container",
            hx_swap="beforeend"
        ),
        class_="mb-4"
    )
    tools_section = Div(
        checkbox.checkbox("tools", "Enable Tools", name="tools"),
        class_="mb-4"
    )
    action_row = Div(
        button.button("Submit", variant="primary", type="submit"),
        button.button("Cancel", variant="outline", type="button", class_="ml-2", onclick=f"toggleSheet('{sheet_id}')"),
        class_="flex justify-end mt-4"
    )
    provider_form = form.Form(
        form.FormField("Provider", "provider", placeholder="Enter provider name", required=True),
        form.FormField("Base URL", "base_url", placeholder="Enter base URL", required=True),
        form.FormField("API Key", "api_key", type="text", placeholder="Enter API key"),
        models_section,
        tools_section,
        form.FormField("Notes", "notes", placeholder="Enter any additional notes"),
        action_row,
        hx_post="/submit/new",
        hx_swap="outerHTML",
        hx_target="body",
        class_="space-y-4"
    )
    body = sheet.SheetBody(Div(provider_form, class_="container mx-auto p-4 max-w-2xl"))
    edit_sheet_content = sheet.SheetContent(header, body)
    rendered = sheet.Sheet(
        sheet_id,
        Div(),
        edit_sheet_content,
        width="80%",
        max_width="800px"
    ).render()
    return rendered
def get_row_data(row_id):
    """Return the provider entry stored at position ``int(row_id)`` in the config."""
    providers = app.state.config["providers"]
    return providers[int(row_id)]
def update_row_data(row_id, updated_data):
    """Replace the provider entry at position ``int(row_id)`` with ``updated_data``.

    Only the row id is logged: the entry itself carries provider API keys and
    must not be written to the process output.
    """
    logger.debug("Updating provider row %s", row_id)
    index = int(row_id)
    app.state.config["providers"][index] = updated_data
@frontend_router.post("/submit/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def submit_form(
    row_id: str,
    request: Request,
    provider: str = FastapiForm(...),
    base_url: str = FastapiForm(...),
    tools: Optional[str] = FastapiForm(None),
    notes: Optional[str] = FastapiForm(None),
    x_api_key: str = Depends(get_api_key)
):
    """Create (``row_id == "new"``) or update a provider from the submitted form.

    API keys come from the ``api_key`` field of the add-provider sheet and the
    ``api_key_<n>`` fields of the edit sheet; empty values are dropped. Models
    come from ``model_name_<id>`` fields, paired with the matching
    ``model_rename_<id>`` field when it is filled in. The configuration is
    saved and the main page is rendered as the response.

    Callers without a valid admin key are redirected to ``/login`` (303).
    """
    if not x_api_key:
        return RedirectResponse(url="/login", status_code=303)
    form_data = await request.form()
    api_keys = [
        value for key, value in form_data.items()
        if (key == "api_key" or key.startswith("api_key_")) and value
    ]
    name_prefix = "model_name_"
    models = []
    for key, value in form_data.items():
        if not key.startswith(name_prefix) or not value:
            continue
        # Everything after the prefix is the model id, even if it contains "_".
        model_id = key[len(name_prefix):]
        rename = form_data.get(f"model_rename_{model_id}")
        models.append({value: rename} if rename else value)
    updated_data = {
        "provider": provider,
        "base_url": base_url,
        # A single key is stored as a plain string, several keys as a list.
        "api": api_keys[0] if len(api_keys) == 1 else api_keys,
        "model": models,
        "tools": tools == "on",
        "notes": notes,
    }
    if row_id == "new":
        app.state.config["providers"].append(updated_data)
    else:
        # Keep settings the form does not edit (for example "preferences")
        # instead of discarding them with the old entry.
        merged = dict(get_row_data(row_id))
        merged.update(updated_data)
        update_row_data(row_id, merged)
    save_api_yaml(app.state.config)
    return await root(x_api_key)
@frontend_router.post("/duplicate/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def duplicate_row(row_id: str, x_api_key: str = Depends(get_api_key)):
    """Insert a copy of the provider at ``row_id`` directly after it.

    The copy's name gets a ``-copy`` suffix; the configuration is saved and
    the main page is rendered as the response. Callers without a valid admin
    key are redirected to ``/login`` (303).
    """
    if not x_api_key:
        return RedirectResponse(url="/login", status_code=303)
    index = int(row_id)
    original_data = app.state.config["providers"][index]
    # Deep copy so the duplicate does not share its nested "api"/"model"
    # lists with the original; later in-place edits would hit both entries.
    new_data = copy.deepcopy(original_data)
    new_data["provider"] += "-copy"
    app.state.config["providers"].insert(index + 1, new_data)
    save_api_yaml(app.state.config)
    return await root(x_api_key)
@frontend_router.delete("/delete/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
async def delete_row(row_id: str, x_api_key: str = Depends(get_api_key)):
    """Delete the provider at position ``row_id``, save, and render the main page.

    Callers without a valid admin key are redirected to ``/login`` (303).
    """
    if not x_api_key:
        return RedirectResponse(url="/login", status_code=303)
    index = int(row_id)
    del app.state.config["providers"][index]
    save_api_yaml(app.state.config)
    return await root(x_api_key)
# Attach both routers to the application once all their routes are defined.
app.include_router(v1_router, tags=["v1"])
app.include_router(frontend_router, tags=["frontend"])
if __name__ == '__main__':
    # Script entry point: serve this module's `app` with uvicorn on port 8000,
    # with auto-reload enabled.
    import uvicorn
    uvicorn.run(
        "__main__:app",
        # NOTE(review): empty host string — confirm which interfaces this binds to.
        host="",
        port=8000,
        reload=True,
        reload_dirs=["./"],
        # Reload on changes to Python sources and to api.yaml.
        reload_includes=["*.py", "api.yaml"],
        ws="none",
    )