uni-api / utils.py
yym68686's picture
Fix error
cf121cf
raw
history blame
10.4 kB
import json
from fastapi import HTTPException
import httpx
from log_config import logger
def update_config(config_data):
    """Normalise each provider's ``model`` list into a dict and collect API keys.

    Every entry of ``provider['model']`` is either:
      * a string -- mapped to itself (``{name: name}``), or
      * a dict ``{old: new, ...}`` -- inverted so the *new* name is the key and
        the *old* name is the value (``{new: old}``).
    Entries of any other type are ignored. Providers are updated in place.

    Args:
        config_data: parsed configuration mapping with ``providers`` (each
            having a ``model`` list) and ``api_keys`` (each having ``api``).

    Returns:
        ``(config_data, api_keys_db, api_list)`` where ``api_keys_db`` is
        ``config_data['api_keys']`` and ``api_list`` holds every entry's
        ``"api"`` value in the same order.

    Raises:
        KeyError: if ``providers``, ``api_keys``, a provider's ``model`` or a
            key entry's ``api`` is missing.
    """
    for provider in config_data['providers']:
        model_dict = {}
        for model in provider['model']:
            if isinstance(model, str):
                model_dict[model] = model
            elif isinstance(model, dict):
                # Key the mapping by the new (exposed) name, pointing back to the old one.
                model_dict.update({new: old for old, new in model.items()})
        # `provider` is the list element itself, so this updates config_data in place.
        provider['model'] = model_dict
    api_keys_db = config_data['api_keys']
    api_list = [item["api"] for item in api_keys_db]
    return config_data, api_keys_db, api_list
# 读取YAML配置文件
async def load_config(app):
    """Load the configuration and return ``(config, api_keys_db, api_list)``.

    ``./api.yaml`` is tried first.

    If that gives no configuration (file missing, empty, undecodable,
    malformed YAML, or missing required keys) and the ``CONFIG_URL``
    environment variable is set, YAML is fetched from that URL with
    ``app.state.client`` instead.

    None of these failures raise. When no usable configuration is found, the
    result is ``([], [], [])``.

    Args:
        app: application object whose ``state.client`` performs the HTTP GET
            (presumably an ``httpx.AsyncClient`` -- confirm at startup).
    """
    import os
    import yaml

    config, api_keys_db, api_list = [], [], []
    try:
        # Explicit UTF-8: the YAML may hold non-ASCII text and the platform
        # default encoding is not guaranteed to be UTF-8.
        with open('./api.yaml', 'r', encoding='utf-8') as f:
            conf = yaml.safe_load(f)
        if conf:
            config, api_keys_db, api_list = update_config(conf)
        # An empty file falls through to CONFIG_URL without logging.
    except FileNotFoundError:
        logger.error("配置文件 'api.yaml' 未找到。请确保文件存在于正确的位置。")
    except (yaml.YAMLError, UnicodeDecodeError):
        logger.error("配置文件 'api.yaml' 格式不正确。请检查 YAML 格式。")
    except (KeyError, TypeError) as e:
        # Raised by update_config when required fields are missing or mistyped;
        # handled like the remote branch below instead of crashing startup.
        logger.error(f"配置文件 'api.yaml' 结构不正确，缺少必要字段或字段类型错误: {e}")

    if config:
        return config, api_keys_db, api_list

    # Fallback: fetch the configuration from the URL given in the environment.
    config_url = os.environ.get('CONFIG_URL')
    if config_url:
        try:
            response = await app.state.client.get(config_url)
            response.raise_for_status()
            config_data = yaml.safe_load(response.text)
            if config_data:
                config, api_keys_db, api_list = update_config(config_data)
            else:
                logger.error(f"Error fetching or parsing config from {config_url}")
        except Exception as e:
            logger.error(f"Error fetching or parsing config from {config_url}: {str(e)}")
    return config, api_keys_db, api_list
def ensure_string(item):
    """Coerce one streamed chunk to ``str``.

    ``str`` passes through unchanged, and ``bytes``/``bytearray`` are decoded as
    UTF-8. A ``dict`` is JSON-encoded into a ``data: <json>`` line followed by
    a blank line. Anything else is converted with ``str()``.
    """
    if isinstance(item, str):
        return item
    if isinstance(item, (bytes, bytearray)):
        return item.decode("utf-8")
    if isinstance(item, dict):
        payload = json.dumps(item)
        return f"data: {payload}\n\n"
    return str(item)
import asyncio
async def error_handling_wrapper(generator, status_code=200):
    """Check the first chunk of an upstream stream for errors, then re-stream it.

    The first item is decoded (bytes -> str), stripped of a ``data:``/``data: ``
    prefix and parsed as JSON.

    An HTTPException with detail ``"data: {'error': 'No data returned'}"`` is
    raised when:
      * the stream is empty,
      * the first item starts with ``[DONE]``, or
      * the first item is not valid JSON.

    If the parsed first item is a dict containing ``'error'``, that error is
    raised as an HTTPException.

    Otherwise every item, starting with the first, is yielded through
    ``ensure_string``.

    This is an async generator, so nothing runs until iteration starts.
    NOTE(review): if the result is passed straight to a streaming response, the
    first-chunk errors fire on first iteration -- confirm against callers.

    Args:
        generator: async iterator of str/bytes/dict chunks.
        status_code: status for the HTTPExceptions raised here, except
            ``httpx.ReadError`` while forwarding, which uses 503.

    Raises:
        HTTPException: on the error conditions described above.
            ``asyncio.CancelledError`` is not caught and propagates unchanged.
    """
    # Phase 1: inspect the first chunk before forwarding anything.
    try:
        first_item = await generator.__anext__()
        first_item_str = first_item
        if isinstance(first_item_str, (bytes, bytearray)):
            first_item_str = first_item_str.decode("utf-8")
        if isinstance(first_item_str, str):
            if first_item_str.startswith("data: "):
                first_item_str = first_item_str[6:]
            elif first_item_str.startswith("data:"):
                first_item_str = first_item_str[5:]
            if first_item_str.startswith("[DONE]"):
                logger.error("error_handling_wrapper [DONE]!")
                raise StopAsyncIteration
            try:
                first_item_str = json.loads(first_item_str)
            except json.JSONDecodeError:
                logger.error("error_handling_wrapper JSONDecodeError!" + repr(first_item_str))
                raise StopAsyncIteration
        if isinstance(first_item_str, dict) and 'error' in first_item_str:
            raise HTTPException(status_code=status_code, detail=f"{first_item_str}"[:300])
    except StopAsyncIteration:
        # Caught inside this frame, so it never escapes the async generator.
        raise HTTPException(status_code=status_code, detail="data: {'error': 'No data returned'}")
    except HTTPException as http_exc:
        raise HTTPException(status_code=status_code, detail=f"Wrapper error: {http_exc}")
    except Exception as e:
        logger.exception(f"Error in error_handling_wrapper: {e}")
        raise HTTPException(status_code=status_code, detail=f"Wrapper error: {e}")

    # Phase 2: forward the stream. Each failure is converted exactly once; the
    # previous nested handlers re-wrapped it several times and discarded the 503.
    try:
        yield ensure_string(first_item)
        async for item in generator:
            yield ensure_string(item)
    except httpx.ReadError as e:
        logger.error(f"Network error during streaming: {e}")
        raise HTTPException(status_code=503, detail=f"Stream interrupted: {e}")
    except Exception as e:
        # asyncio.CancelledError (a BaseException) deliberately bypasses this.
        logger.exception(f"Error while streaming in error_handling_wrapper: {e}")
        raise HTTPException(status_code=status_code, detail=f"Stream error: {e}")
# async def error_handling_wrapper(generator, status_code=200):
# try:
# first_item = await generator.__anext__()
# first_item_str = first_item
# if isinstance(first_item_str, (bytes, bytearray)):
# first_item_str = first_item_str.decode("utf-8")
# if isinstance(first_item_str, str):
# if first_item_str.startswith("data: "):
# first_item_str = first_item_str[6:]
# elif first_item_str.startswith("data:"):
# first_item_str = first_item_str[5:]
# if first_item_str.startswith("[DONE]"):
# logger.error("error_handling_wrapper [DONE]!")
# raise StopAsyncIteration
# try:
# first_item_str = json.loads(first_item_str)
# except json.JSONDecodeError:
# logger.error("error_handling_wrapper JSONDecodeError!" + repr(first_item_str))
# raise StopAsyncIteration
# if isinstance(first_item_str, dict) and 'error' in first_item_str:
# # 如果第一个 yield 的项是错误信息,抛出 HTTPException
# raise HTTPException(status_code=status_code, detail=f"{first_item_str}"[:300])
# # 如果不是错误,创建一个新的生成器,首先yield第一个项,然后yield剩余的项
# async def new_generator():
# yield ensure_string(first_item)
# async for item in generator:
# yield ensure_string(item)
# return new_generator()
# except StopAsyncIteration:
# raise HTTPException(status_code=status_code, detail="data: {'error': 'No data returned'}")
def post_all_models(token, config, api_list):
    """Return the model-list entries visible to the API key ``token``.

    Each entry of the key's ``model`` list is interpreted as:
      * ``"provider/*"``: every model of that provider;
      * ``"provider/name"``: that model, if the provider lists it (``name``
        may itself contain ``/``; only the first ``/`` separates the provider);
      * ``"name"`` (no ``/``): listed as-is with ``owned_by`` set to the name.
    Duplicate ids are dropped, keeping first-seen order. A key with no or an
    empty ``model`` list yields ``[]``.

    Args:
        token: API key presented by the client.
        config: configuration whose providers' ``model`` values are dicts
            keyed by model name (as produced by ``update_config``).
        api_list: API key strings, index-aligned with ``config['api_keys']``.

    Raises:
        HTTPException: 403 when ``token`` is not in ``api_list``.
    """
    if token not in api_list:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")
    api_index = api_list.index(token)

    all_models = []
    unique_models = set()

    def add(model_id, owned_by="uni-api"):
        # Append a model entry once per id, preserving first-seen order.
        if model_id in unique_models:
            return
        unique_models.add(model_id)
        all_models.append({
            "id": model_id,
            "object": "model",
            "created": 1720524448858,
            "owned_by": owned_by,
        })

    for model in config['api_keys'][api_index].get('model') or []:
        if "/" not in model:
            add(model, owned_by=model)
            continue
        # maxsplit=1 keeps names such as "org/model" after the provider intact.
        provider, model_name = model.split("/", 1)
        for provider_item in config["providers"]:
            if provider_item['provider'] != provider:
                continue
            if model_name == "*":
                for model_item in provider_item['model'].keys():
                    add(model_item)
            elif model_name in provider_item['model']:
                add(model_name)

    return all_models
def get_all_models(config):
    """Return one model-list entry per distinct model id across all providers.

    Providers and their models are scanned in order, and the first occurrence
    of each id wins. Every entry reports ``owned_by`` as ``"uni-api"``.
    """
    entries_by_id = {}
    for provider in config["providers"]:
        for model_id in provider['model'].keys():
            if model_id in entries_by_id:
                continue
            entries_by_id[model_id] = {
                "id": model_id,
                "object": "model",
                "created": 1720524448858,
                "owned_by": "uni-api"
            }
    # Dicts keep insertion order, so this matches first-seen order.
    return list(entries_by_id.values())