# Source: uni-api / main.py (commit b21d1ea, yym68686)
# Commit message: Support downloading configuration files from external URLs.
from log_config import logger
import httpx
import secrets
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Depends
from models import RequestModel
from utils import config, api_keys_db, api_list, error_handling_wrapper, get_all_models, verify_api_key, post_all_models, update_config
from request import get_payload
from response import fetch_response, fetch_response_stream
from typing import List, Dict
from urllib.parse import urlparse
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared HTTP client on startup and close it on shutdown.

    If the ``CONFIG_URL`` environment variable is set, the document at that
    URL is fetched with the new client, parsed with ``yaml.safe_load`` and
    passed to ``update_config``; the returned triple replaces this module's
    ``config``, ``api_keys_db`` and ``api_list`` globals. A failure to fetch
    or parse is logged and startup continues with the values that were
    imported from ``utils``.

    Args:
        app: The FastAPI application. The client is stored on
            ``app.state.client`` for the request handlers to use.
    """
    # NOTE(review): this rebinds only the names in *this* module. Whether the
    # helpers imported from ``utils`` observe the new configuration is not
    # visible from here -- confirm against ``utils``.
    global config, api_keys_db, api_list

    import os
    import yaml

    # Startup: one pooled client shared by every request.
    timeout = httpx.Timeout(connect=15.0, read=10.0, write=30.0, pool=30.0)
    app.state.client = httpx.AsyncClient(timeout=timeout)
    try:
        # Optionally pull the configuration from a remote URL.
        config_url = os.environ.get('CONFIG_URL')
        if config_url:
            try:
                response = await app.state.client.get(config_url)
                response.raise_for_status()
                config_data = yaml.safe_load(response.text)
                config, api_keys_db, api_list = update_config(config_data)
            except Exception as e:
                # Best effort: keep serving with the already-loaded config.
                logger.error(f"Error fetching or parsing config from {config_url}: {str(e)}")
        yield
    finally:
        # Shutdown: runs even when an exception is thrown into the generator,
        # so the client's connections are always released.
        await app.state.client.aclose()
# Application instance; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)
# Configure the CORS middleware.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # allow every origin
    allow_credentials=True,
    allow_methods=["*"], # allow every HTTP method
    allow_headers=["*"], # allow every request header
)
async def process_request(request: RequestModel, provider: Dict):
    """Send one chat request to ``provider`` and return the upstream response.

    The upstream API flavour ("engine") is picked in three steps:

    1. from the host (and path) of ``provider['base_url']``;
    2. forced to "openrouter" when the provider-side model name contains
       none of "claude", "gpt" or "gemini";
    3. replaced by ``provider["engine"]`` when that key holds a truthy value.

    Args:
        request: The incoming request; ``request.model`` must be a key of
            ``provider['model']`` and ``request.stream`` selects streaming.
        provider: One provider entry from the configuration.

    Returns:
        A ``StreamingResponse`` (server-sent events) for streaming requests,
        otherwise whatever ``fetch_response`` returns.
    """
    base_url = provider['base_url']
    target = urlparse(base_url)
    host = target.netloc

    # Step 1: engine implied by the endpoint; anything unknown is "gpt".
    engines_by_host = {
        'generativelanguage.googleapis.com': "gemini",
        'api.anthropic.com': "claude",
        'openrouter.ai': "openrouter",
    }
    engine = engines_by_host.get(host, "gpt")
    # A ".../v1/messages" path means "claude" on every host except the
    # Gemini one, whose host check takes precedence.
    if host != 'generativelanguage.googleapis.com' and target.path.endswith("v1/messages"):
        engine = "claude"

    # Step 2: model names outside the known families go through "openrouter".
    upstream_model = provider['model'][request.model]
    known_families = ("claude", "gpt", "gemini")
    if not any(family in upstream_model for family in known_families):
        engine = "openrouter"

    # Step 3: an explicit engine in the provider entry wins.
    explicit_engine = provider.get("engine")
    if explicit_engine:
        engine = explicit_engine

    logger.info(f"provider: {provider['provider']:<10} engine: {engine}")

    url, headers, payload = await get_payload(request, engine, provider)

    if not request.stream:
        return await fetch_response(app.state.client, url, headers, payload)

    # The model name is read again here, after get_payload has run.
    model = provider['model'][request.model]
    stream = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
    guarded_stream = await error_handling_wrapper(stream, status_code=500)
    return StreamingResponse(guarded_stream, media_type="text/event-stream")
class ModelRequestHandler:
    """Pick the providers an API key may use for a model and dispatch to them.

    ``last_provider_index`` holds the index, within the most recent candidate
    list, of the provider attempted last; ``try_all_providers`` uses it as
    the round-robin cursor.
    """

    def __init__(self):
        self.last_provider_index = -1

    def get_matching_providers(self, model_name, token):
        """Return the providers that ``token`` may use for ``model_name``.

        Every entry of the key's ``model`` list is treated as a rule:

        * ``"<provider>/*"``       -- any model served by that provider;
        * ``"<provider>/<model>"`` -- only ``<model>`` on that provider;
        * ``"<model>"``            -- ``<model>`` on any provider serving it.

        Args:
            model_name: The model requested by the client.
            token: The API key; it must be present in ``api_list``.

        Returns:
            The matching provider entries in configuration order, each at
            most once (duplicates would skew round-robin and retries).

        Raises:
            ValueError: If ``token`` is not in ``api_list``.
        """
        api_index = api_list.index(token)

        any_provider_allowed = False
        allowed_provider_names = set()
        for rule in config['api_keys'][api_index]['model']:
            if "/" in rule:
                # Split on the first "/" only, so model names that themselves
                # contain "/" are kept intact.
                provider_name, _, rule_model = rule.partition("/")
                # A scoped rule grants the wildcard or exactly the named
                # model -- not every model of the provider.
                if rule_model in ("*", model_name):
                    allowed_provider_names.add(provider_name)
            elif rule == model_name:
                any_provider_allowed = True

        # Provider names are compared exactly (not by substring), and a
        # provider is only returned when it actually serves the model.
        return [
            provider
            for provider in config['providers']
            if model_name in provider['model']
            and (any_provider_allowed or provider['provider'] in allowed_provider_names)
        ]

    async def request_model(self, request: RequestModel, token: str):
        """Resolve the providers for ``request.model`` and run the request.

        Reads the key's optional ``preferences`` mapping for the
        ``USE_ROUND_ROBIN`` and ``AUTO_RETRY`` switches.

        Raises:
            HTTPException: 404 when no provider matches the model and key.
        """
        model_name = request.model
        matching_providers = self.get_matching_providers(model_name, token)
        if not matching_providers:
            raise HTTPException(status_code=404, detail="No matching model found")

        # Per-key switches: round-robin rotation and retry on failure.
        api_index = api_list.index(token)
        use_round_robin = False
        auto_retry = False
        preferences = config['api_keys'][api_index].get("preferences")
        if preferences:
            use_round_robin = preferences.get("USE_ROUND_ROBIN")
            auto_retry = preferences.get("AUTO_RETRY")

        return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry)

    async def try_all_providers(self, request: RequestModel, providers: List[Dict], use_round_robin: bool, auto_retry: bool):
        """Try ``providers`` in turn until one of them answers.

        Args:
            request: The request to forward.
            providers: Non-empty list of candidate provider entries.
            use_round_robin: Start after the provider used last time instead
                of always starting at the first one.
            auto_retry: Move on to the next provider after a failure instead
                of failing immediately.

        Returns:
            The response of the first provider that succeeds.

        Raises:
            HTTPException: 500 on the first failure when ``auto_retry`` is
                falsy, or once every attempt has failed.
        """
        num_providers = len(providers)
        start_index = self.last_provider_index + 1 if use_round_robin else 0

        # One attempt more than there are providers: with auto_retry the
        # starting provider is tried a second time after all others failed.
        for i in range(num_providers + 1):
            self.last_provider_index = (start_index + i) % num_providers
            provider = providers[self.last_provider_index]
            try:
                return await process_request(request, provider)
            except Exception as e:
                logger.error(f"Error with provider {provider['provider']}: {str(e)}")
                if not auto_retry:
                    # Keep the original failure attached as the cause.
                    raise HTTPException(status_code=500, detail="Error: Current provider response failed!") from e

        raise HTTPException(status_code=500, detail="All providers failed")
# Single module-level handler; its `last_provider_index` cursor is therefore
# shared by every request served by this process.
model_handler = ModelRequestHandler()
@app.post("/v1/chat/completions")
async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
    """Handle a chat-completion request.

    ``token`` is supplied by the ``verify_api_key`` dependency; the work is
    delegated to the shared ``model_handler``.
    """
    result = await model_handler.request_model(request, token)
    return result
@app.options("/v1/chat/completions")
async def options_handler():
    """Answer OPTIONS requests on the chat-completion route with HTTP 200."""
    body = {"detail": "OPTIONS allowed"}
    return JSONResponse(content=body, status_code=200)
@app.post("/v1/models")
async def list_models(token: str = Depends(verify_api_key)):
    """Return the result of ``post_all_models(token)`` as a model list.

    ``token`` is supplied by the ``verify_api_key`` dependency. The response
    body has the keys ``object`` (always "list") and ``data``.
    """
    listing = {
        "object": "list",
        "data": post_all_models(token),
    }
    return JSONResponse(content=listing)
@app.get("/v1/models")
async def list_models():
    """Return the result of ``get_all_models()`` as a model list.

    This route declares no API-key dependency. The response body has the
    keys ``object`` (always "list") and ``data``.
    """
    listing = {
        "object": "list",
        "data": get_all_models(),
    }
    return JSONResponse(content=listing)
@app.get("/generate-api-key")
def generate_api_key():
    """Return a freshly generated random key of the form ``sk-<token>``.

    The key is only generated and returned; this function does not store it.
    """
    random_part = secrets.token_urlsafe(32)
    return JSONResponse(content={"api_key": "sk-" + random_part})
# async def on_fetch(request, env):
# import asgi
# return await asgi.fetch(app, request, env)
if __name__ == '__main__':
    # Script entry point: serve this module's `app` with uvicorn.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
        "ws": "none",
        "log_level": "warning",
    }
    uvicorn.run("__main__:app", **server_options)