# gemini-balancer / main.py
# Author: snailyp
# Last change: "Update main.py" (commit 9126c93, verified)
from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import openai
from typing import List, Optional, Union
import logging
from itertools import cycle
import asyncio
import uvicorn
from app import config
import requests
from datetime import datetime, timezone
import json
import httpx
import uuid
import time
# Logging: INFO and above, one "time - level - message" line per record.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
app = FastAPI()

# CORS: accept any origin, method and header (credentials allowed), so
# browser-based clients served from other origins can call this API.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# API key pool, read once from settings at import time.
API_KEYS = config.settings.API_KEYS
# Endless round-robin iterator over the pool, shared by every request.
# NOTE(review): if API_KEYS is empty, next(key_cycle) raises StopIteration.
key_cycle = cycle(API_KEYS)
# Two independent locks: one serialises access to key_cycle, the other
# guards key_failure_counts.
key_cycle_lock = asyncio.Lock()
failure_count_lock = asyncio.Lock()
# Per-key failure counter; every key starts at zero.
key_failure_counts = dict.fromkeys(API_KEYS, 0)
# A key whose counter reaches this value is treated as invalid.
MAX_FAILURES = 10
# Attempts per chat request before giving up.
MAX_RETRIES = 3
async def get_next_key():
    """Advance the shared round-robin iterator and return the key it yields.

    Failure counters are not consulted here; this only rotates.
    """
    async with key_cycle_lock:
        key = next(key_cycle)
    return key
async def is_key_valid(key):
    """Return True while ``key``'s failure counter is below MAX_FAILURES."""
    async with failure_count_lock:
        failures = key_failure_counts[key]
    return failures < MAX_FAILURES
async def reset_failure_counts():
    """Set the failure counter of every known key back to zero.

    The dict is updated in place while holding the counter lock.
    """
    async with failure_count_lock:
        key_failure_counts.update(dict.fromkeys(key_failure_counts, 0))
async def get_next_working_key():
    """Return the next key in rotation whose failure count is below the limit.

    Keys are pulled from the round-robin iterator until a valid one is found.
    If the iterator comes back around to the first key pulled in this call,
    every counter is reset and that key is returned. A request therefore
    always gets a key, even when all of them have been marked as failing.
    """
    first = await get_next_key()
    candidate = first
    while not await is_key_valid(candidate):
        candidate = await get_next_key()
        if candidate == first:
            # Went around the whole pool without finding a usable key.
            await reset_failure_counts()
            break
    return candidate
async def handle_api_failure(api_key):
    """Record one failure for ``api_key`` and return the key to use next.

    A warning is logged once the counter reaches MAX_FAILURES. The
    replacement key is chosen only after the counter lock is released.
    get_next_working_key reaches is_key_valid/reset_failure_counts, which take
    the same lock, and asyncio.Lock is not re-entrant.
    """
    async with failure_count_lock:
        failures = key_failure_counts[api_key] + 1
        key_failure_counts[api_key] = failures
        if failures >= MAX_FAILURES:
            logger.warning(f"API key {api_key} has failed {MAX_FAILURES} times, switching to next key")
    return await get_next_working_key()
class ChatRequest(BaseModel):
    # Request body for the chat completions endpoint (OpenAI-compatible subset).
    # OpenAI-style messages, e.g. {"role": "user", "content": "..."}.
    messages: List[dict]
    model: str = "gemini-1.5-flash-002"
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False
    # Accepted for client compatibility; the chat handler does not forward it.
    tools: Optional[List[dict]] = []
    # OpenAI allows either a string ("auto" / "none" / "required") or an object
    # such as {"type": "function", "function": {"name": ...}}. Accept both so
    # such requests are not rejected with a 422 validation error.
    tool_choice: Optional[Union[str, dict]] = "auto"
class EmbeddingRequest(BaseModel):
    # Request body for the embeddings endpoint (OpenAI-compatible subset).
    # A single string or a list of strings to embed.
    input: Union[str, List[str]]
    model: str = "text-embedding-004"
    # Parsed, but the embedding handler in this file does not pass it upstream.
    encoding_format: Optional[str] = "float"
async def verify_authorization(authorization: str = Header(None)):
    """Validate an ``Authorization: Bearer <token>`` header.

    Args:
        authorization: Raw header value (None when the header is absent).

    Returns:
        The bearer token, when it is listed in ``config.settings.ALLOWED_TOKENS``.

    Raises:
        HTTPException: 401 when the header is missing, malformed, or the
            token is not allowed.
    """
    if not authorization:
        logger.error("Missing Authorization header")
        raise HTTPException(status_code=401, detail="Missing Authorization header")
    if not authorization.startswith("Bearer "):
        logger.error("Invalid Authorization header format")
        raise HTTPException(
            status_code=401, detail="Invalid Authorization header format"
        )
    # Strip only the leading scheme. str.replace would also delete any
    # "Bearer " sequence inside the token, letting a mangled value match.
    token = authorization[len("Bearer "):]
    if token not in config.settings.ALLOWED_TOKENS:
        logger.error("Invalid token")
        raise HTTPException(status_code=401, detail="Invalid token")
    return token
def get_gemini_models(api_key, timeout=30):
    """Fetch the Gemini model list and convert it to the OpenAI ``/v1/models`` shape.

    This is a blocking call (``requests``).

    Args:
        api_key: Gemini API key, sent as the ``key`` query parameter.
        timeout: Seconds to wait for the upstream request. ``requests`` has no
            default timeout, so without one a stalled connection never returns.

    Returns:
        The converted model list, or ``None`` if the request failed or the
        upstream answered with a non-200 status.
    """
    base_url = "https://generativelanguage.googleapis.com/v1beta"
    url = f"{base_url}/models?key={api_key}"
    try:
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            logger.error(f"Error: {response.status_code}")
            logger.error(response.text)
            return None
        return convert_to_openai_models_format(response.json())
    except requests.RequestException as e:
        logger.error(f"Request failed: {e}")
        return None
def convert_to_openai_models_format(gemini_models):
    """Map a Gemini ``models.list`` payload onto the OpenAI ``/v1/models`` shape.

    Only each model's ``name`` (e.g. ``"models/gemini-pro"``) is read. The
    OpenAI-only fields are filled with fixed placeholders, and ``created`` is
    the current UTC time because no equivalent value is taken from Gemini.
    """

    def to_openai(entry):
        full_name = entry["name"]
        return {
            "id": full_name.split("/")[-1],  # last path segment is the bare model id
            "object": "model",
            "created": int(datetime.now(timezone.utc).timestamp()),
            "owned_by": "google",
            "permission": [],
            "root": full_name,
            "parent": None,
        }

    entries = gemini_models.get("models", [])
    return {"object": "list", "data": [to_openai(entry) for entry in entries]}
def _content_to_gemini_parts(content):
    """Turn one OpenAI message ``content`` value into a list of Gemini ``parts``."""
    if isinstance(content, list):
        # OpenAI multimodal form: [{"type": "text", "text": ...}, {"type": "image_url", ...}].
        # Only text parts are forwarded; other part types are dropped.
        return [
            {"text": part.get("text", "")}
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        ]
    return [{"text": content}]


def convert_messages_to_gemini_format(messages):
    """Convert OpenAI-style chat messages to Gemini ``contents`` entries.

    The role ``"user"`` stays ``"user"``; every other role becomes ``"model"``.
    String content becomes a single text part. List (multimodal) content
    contributes one text part per ``{"type": "text"}`` item.

    NOTE(review): ``"system"`` messages are mapped to ``"model"`` here; Gemini
    also has a dedicated system-instruction field. Confirm the mapping is intended.
    """
    return [
        {
            "role": "user" if message["role"] == "user" else "model",
            "parts": _content_to_gemini_parts(message["content"]),
        }
        for message in messages
    ]
def _join_text_parts(parts):
    """Concatenate the ``text`` of every Gemini part that has one; None if none do."""
    texts = [part["text"] for part in parts if isinstance(part, dict) and "text" in part]
    return "".join(texts) if texts else None


def convert_gemini_response_to_openai(response, model, stream=False):
    """Convert a Gemini reply (or one streamed chunk) to the OpenAI format.

    Only the first candidate is used. The text of all of its parts is
    concatenated, because Gemini may split one answer over several parts.

    Args:
        response: Parsed Gemini JSON; with ``stream=True`` this is one SSE chunk.
        model: Model name echoed back in the ``model`` field.
        stream: Build a ``chat.completion.chunk`` instead of a ``chat.completion``.

    Returns:
        The OpenAI-shaped dict. In stream mode, returns ``None`` for chunks that
        carry no text (no candidates, or finish/metadata-only chunks), so the
        caller can skip them instead of failing mid-stream.

    Raises:
        KeyError, IndexError: in non-stream mode, when the reply has no
            candidate content (e.g. an error body). The chat handler's broad
            ``except`` turns this into a retry with another key.
    """
    if stream:
        candidates = response.get("candidates")
        if not candidates:
            return None
        parts = (candidates[0].get("content") or {}).get("parts") or []
        text = _join_text_parts(parts)
        if text is None:
            return None
        return {
            "id": "chatcmpl-" + str(uuid.uuid4()),
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "delta": {"content": text},
                    "finish_reason": None,
                }
            ],
        }

    # Strict lookups on purpose: a missing candidate must still raise.
    parts = response["candidates"][0]["content"]["parts"]
    text = _join_text_parts(parts)
    return {
        "id": "chatcmpl-" + str(uuid.uuid4()),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": text if text is not None else "",
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }
@app.get("/v1/models")
@app.get("/hf/v1/models")
async def list_models(authorization: str = Header(None)):
    """List the available Gemini models in OpenAI ``/v1/models`` format.

    get_gemini_models uses the blocking ``requests`` library, so it runs in
    the default thread-pool executor to keep the event loop responsive.

    Raises:
        HTTPException: 401 from verify_authorization; 500 if the upstream call
            raises or reports failure (get_gemini_models returns None).
    """
    await verify_authorization(authorization)
    api_key = await get_next_working_key()
    logger.info(f"Using API key: {api_key}")
    try:
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, get_gemini_models, api_key)
    except Exception as e:
        logger.error(f"Error listing models: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    if response is None:
        # get_gemini_models swallows upstream errors and returns None; do not
        # hand clients a 200 with a null body.
        logger.error("Error listing models: upstream request failed")
        raise HTTPException(status_code=500, detail="Failed to retrieve models list from upstream")
    logger.info("Successfully retrieved models list")
    return response
_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
# httpx aborts after 5 s of inactivity by default. Gemini generation with search
# grounding routinely takes longer, which used to surface as a spurious key
# failure plus a retry.
_GEMINI_HTTP_TIMEOUT = httpx.Timeout(300.0, connect=10.0)


def _build_gemini_search_payload(request):
    """Build a Gemini generateContent body with Google Search grounding enabled."""
    return {
        "contents": convert_messages_to_gemini_format(request.messages),
        "generationConfig": {
            "temperature": request.temperature,
        },
        "tools": [{"googleSearch": {}}],
    }


async def _stream_gemini_search(request, payload, api_key, retries):
    """Relay Gemini ``streamGenerateContent`` (SSE) as OpenAI-style SSE lines.

    Rotates to another key on HTTP 429, or on an exception raised before any
    content was sent, sharing the retry budget the caller already consumed.
    Failures are reported in-band as ``data: {"error": ...}`` events.
    """
    sent_content = False
    while retries < MAX_RETRIES:
        try:
            stream_url = f"{_GEMINI_BASE_URL}/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}"
            async with httpx.AsyncClient(timeout=_GEMINI_HTTP_TIMEOUT) as client:
                async with client.stream("POST", stream_url, json=payload) as response:
                    if response.status_code == 429:
                        logger.warning(f"Rate limit reached for key: {api_key}")
                        api_key = await handle_api_failure(api_key)
                        logger.info(f"Retrying with new API key: {api_key}")
                        retries += 1
                        if retries >= MAX_RETRIES:
                            yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
                            break
                        continue
                    if response.status_code != 200:
                        logger.error(f"Error in streaming response: {response.status_code}")
                        yield f"data: {json.dumps({'error': f'API error: {response.status_code}'})}\n\n"
                        break
                    async for line in response.aiter_lines():
                        if not line.startswith("data: "):
                            continue
                        try:
                            chunk = json.loads(line[6:])
                        except json.JSONDecodeError:
                            continue
                        openai_chunk = convert_gemini_response_to_openai(
                            chunk, request.model, stream=True
                        )
                        if openai_chunk:
                            sent_content = True
                            yield f"data: {json.dumps(openai_chunk)}\n\n"
                    yield "data: [DONE]\n\n"
                    return
        except Exception as e:
            logger.error(f"Stream error: {str(e)}")
            api_key = await handle_api_failure(api_key)
            if sent_content:
                # Re-sending the request would replay the whole answer and the
                # client would receive duplicated text; stop here instead.
                yield f"data: {json.dumps({'error': 'Stream interrupted'})}\n\n"
                break
            retries += 1
            if retries >= MAX_RETRIES:
                yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
                break


async def _gemini_search_completion(request, payload, api_key):
    """Call Gemini generateContent once and convert the reply to OpenAI format.

    Raises on transport errors, a non-2xx status, or a reply without candidate
    content, so the caller's retry loop can rotate to another key.
    """
    url = f"{_GEMINI_BASE_URL}/models/{request.model}:generateContent?key={api_key}"
    async with httpx.AsyncClient(timeout=_GEMINI_HTTP_TIMEOUT) as client:
        response = await client.post(url, json=payload)
    response.raise_for_status()
    result = convert_gemini_response_to_openai(response.json(), request.model)
    logger.info("Chat completion successful")
    return result


async def _openai_chat_completion(request, api_key):
    """Forward the request to the OpenAI-compatible endpoint at ``BASE_URL``.

    Uses the async client, so waiting on the upstream (and iterating a stream)
    does not block the event loop. Errors raised while creating the request
    propagate so the caller can retry with another key.
    """
    params = {
        "model": request.model,
        "messages": request.messages,
        "temperature": request.temperature,
    }
    client = openai.AsyncOpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
    if not request.stream:
        async with client:
            response = await client.chat.completions.create(stream=False, **params)
        logger.info("Chat completion successful")
        return response

    try:
        stream = await client.chat.completions.create(stream=True, **params)
    except Exception:
        await client.close()
        raise
    logger.info("Streaming response enabled")

    async def generate():
        try:
            async for chunk in stream:
                yield f"data: {chunk.model_dump_json()}\n\n"
            # OpenAI-compatible streams end with an explicit [DONE] sentinel,
            # matching what the Gemini search path sends.
            yield "data: [DONE]\n\n"
            logger.info("Chat completion successful")
        finally:
            await client.close()

    return StreamingResponse(content=generate(), media_type="text/event-stream")


@app.post("/v1/chat/completions")
@app.post("/hf/v1/chat/completions")
async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
    """OpenAI-compatible chat completions endpoint.

    Models listed in ``config.settings.MODEL_SEARCH`` are sent straight to the
    Gemini REST API with Google Search grounding. Every other model is
    forwarded to the OpenAI-compatible endpoint at ``config.settings.BASE_URL``.

    An exception while starting the request counts as a failure for the
    current key (handle_api_failure), and the request is retried with the
    next key, up to MAX_RETRIES attempts in total. For Gemini streaming, the
    remaining budget is handed to the stream generator.

    Raises:
        HTTPException: 401 from verify_authorization; 500 once retries are exhausted.
    """
    await verify_authorization(authorization)
    api_key = await get_next_working_key()
    logger.info(f"Chat completion request - Model: {request.model}")
    retries = 0
    while retries < MAX_RETRIES:
        try:
            logger.info(f"Attempt {retries + 1} with API key: {api_key}")
            if request.model in config.settings.MODEL_SEARCH:
                payload = _build_gemini_search_payload(request)
                if request.stream:
                    logger.info("Streaming response enabled")
                    return StreamingResponse(
                        content=_stream_gemini_search(request, payload, api_key, retries),
                        media_type="text/event-stream",
                    )
                return await _gemini_search_completion(request, payload, api_key)
            return await _openai_chat_completion(request, api_key)
        except Exception as e:
            logger.error(f"Error in chat completion: {str(e)}")
            api_key = await handle_api_failure(api_key)
            retries += 1
            if retries >= MAX_RETRIES:
                logger.error("Max retries reached, giving up")
                raise HTTPException(status_code=500, detail="Max retries reached with all available API keys")
            logger.info(f"Retrying with new API key: {api_key}")
    raise HTTPException(status_code=500, detail="Unexpected error in chat completion")
@app.post("/v1/embeddings")
@app.post("/hf/v1/embeddings")
async def embedding(request: EmbeddingRequest, authorization: str = Header(None)):
    """Create embeddings through the OpenAI-compatible endpoint at ``BASE_URL``.

    Uses the async OpenAI client, closed on exit, so the upstream call does not
    block the event loop. Unlike chat_completion, no key rotation or retry is
    attempted here.

    Raises:
        HTTPException: 401 from verify_authorization; 500 on any upstream error.
    """
    await verify_authorization(authorization)
    api_key = await get_next_working_key()
    logger.info(f"Using API key: {api_key}")
    try:
        async with openai.AsyncOpenAI(api_key=api_key, base_url=config.settings.BASE_URL) as client:
            response = await client.embeddings.create(input=request.input, model=request.model)
        logger.info("Embedding successful")
        return response
    except Exception as e:
        logger.error(f"Error in embedding: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
@app.get("/")
async def health_check():
    """Liveness probe served at both ``/health`` and ``/``."""
    logger.info("Health check endpoint called")
    status = {"status": "healthy"}
    return status
if __name__ == "__main__":
    # Serve the app directly when the file is run as a script.
    uvicorn.run(app, port=8000, host="0.0.0.0")