# Spaces: Sleeping  (page-header residue captured with the file; not part of the module)
import sys, os, platform, time, copy, re, asyncio, inspect | |
import threading, ast | |
import shutil, random, traceback, requests | |
from datetime import datetime, timedelta, timezone | |
from typing import Optional, List | |
import secrets, subprocess | |
import hashlib, uuid | |
import warnings | |
import importlib | |
# Module-level list of messages; starts empty.
messages: list = []
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path - for litellm local dev
# Example client snippet; it is embedded verbatim in `description` below.
sample = """
from openai import OpenAI
import json
base_url = "https://ka1kuk-litellm.hf.space"
api_key = "hf_xxxx"
client = OpenAI(base_url=base_url, api_key=api_key)
messages = [{"role": "user", "content": "What's the capital of France?"}]
response = client.chat.completions.create(
model="huggingface/mistralai/Mixtral-8x7B-Instruct-v0.1",
response_format={ "type": "json_object" },
messages=messages,
stream=False,
)
print(response.choices[0].message.content)
"""
# Text passed as the FastAPI app `description` when the app is constructed.
description = f"Proxy Server to call 100+ LLMs in the OpenAI format\n\nSample with openai library:\n\n{sample}"
try: | |
import fastapi | |
import backoff | |
import yaml | |
import orjson | |
import logging | |
except ImportError as e: | |
raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`") | |
import litellm | |
from litellm.proxy.utils import ( | |
PrismaClient, | |
DBClient, | |
get_instance_fn, | |
ProxyLogging, | |
_cache_user_row, | |
send_email, | |
) | |
from litellm.proxy.secret_managers.google_kms import load_google_kms | |
import pydantic | |
from litellm.proxy._types import * | |
from litellm.caching import DualCache | |
from litellm.proxy.health_check import perform_health_check | |
from litellm._logging import verbose_router_logger, verbose_proxy_logger | |
litellm.suppress_debug_info = True | |
from fastapi import ( | |
FastAPI, | |
Request, | |
HTTPException, | |
status, | |
Depends, | |
BackgroundTasks, | |
Header, | |
Response, | |
) | |
from fastapi.routing import APIRouter | |
from fastapi.security import OAuth2PasswordBearer | |
from fastapi.encoders import jsonable_encoder | |
from fastapi.responses import StreamingResponse, FileResponse, ORJSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.security.api_key import APIKeyHeader | |
import json | |
import logging | |
from typing import Union | |
# FastAPI application; the interactive docs are served at the root path.
app = FastAPI(
    docs_url="/",
    title="LiteLLM API",
    description= description,
)
router = APIRouter()
# NOTE(review): wildcard origins/methods/headers combined with
# allow_credentials=True is a very permissive CORS policy - confirm intended.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
from typing import Dict | |
# Reads the API key from the "Authorization" header; auto_error=False means a
# missing header is handed to the dependency instead of failing automatically.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
# Per-process user settings; all start unset/defaulted here.
user_api_base = None
user_model = None
user_debug = False
user_max_tokens = None
user_request_timeout = None
user_temperature = None
user_telemetry = True
user_config = None
user_headers = None
# Default config file name, stamped with the import time in whole seconds.
user_config_file_path = f"config_{int(time.time())}.yaml"
local_logging = True  # writes logs to a local api_log.json file for debugging
experimental = False
#### GLOBAL VARIABLES ####
llm_router: Optional[litellm.Router] = None
llm_model_list: Optional[list] = None
general_settings: dict = {}
log_file = "api_log.json"
worker_config = None
master_key = None
otel_logging = False
# Database clients; both stay None until one is configured.
prisma_client: Optional[PrismaClient] = None
custom_db_client: Optional[DBClient] = None
# Cache shared by key authentication and the proxy logging object.
user_api_key_cache = DualCache()
user_custom_auth = None
use_background_health_checks = None
use_queue = False
health_check_interval = None
# Filled in by the background health-check loop.
health_check_results = {}
queue: List = []
### INITIALIZE GLOBAL LOGGING OBJECT ###
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
### REDIS QUEUE ###
async_result = None
celery_app_conn = None
celery_fn = None  # Redis Queue for handling requests
### logger ###
def usage_telemetry(
    feature: str,
):  # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
    """Report that ``feature`` was used, on a background daemon thread.

    Does nothing when the module-level ``user_telemetry`` flag is falsy.

    Args:
        feature: Name of the feature being reported.
    """
    if not user_telemetry:
        return
    payload = {"feature": feature}  # "local_proxy_server"
    reporter = threading.Thread(
        target=litellm.utils.litellm_telemetry, args=(payload,), daemon=True
    )
    reporter.start()
def _get_bearer_token(api_key: str): | |
assert api_key.startswith("Bearer ") # ensure Bearer token passed in | |
api_key = api_key.replace("Bearer ", "") # extract the token | |
return api_key | |
def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
    """Convert a pydantic model to a plain dict.

    Tries ``model_dump()`` first and falls back to ``dict()`` if that fails.

    Args:
        pydantic_obj: The model instance to convert.

    Returns:
        The model's fields as a dictionary.
    """
    try:
        return pydantic_obj.model_dump()  # type: ignore
    except Exception:
        # `Exception` rather than a bare `except:` so KeyboardInterrupt and
        # SystemExit are never swallowed by the fallback.
        # if using pydantic v1
        return pydantic_obj.dict()
async def user_api_key_auth(
    request: Request, api_key: str = fastapi.Security(api_key_header)
) -> UserAPIKeyAuth:
    """Authenticate the caller's API key for a request.

    Resolution order:
      1. If a user-defined auth function is configured, delegate to it.
      2. If no master key is set, accept any (or no) key.
      3. ``/user/auth`` is open only when ``allow_user_auth`` is enabled.
      4. A key equal to the master key is accepted.
      5. Otherwise the key is looked up in the in-memory cache, then in the
         connected database; its expiry and allowed models are checked.

    Args:
        request: The incoming request (its path and JSON body are inspected).
        api_key: Raw ``Authorization`` header value, or None if absent.

    Returns:
        A ``UserAPIKeyAuth`` describing the authenticated key.

    Raises:
        HTTPException: 403 when ``/user/auth`` is disabled or the key is
            expired; 401 ("invalid user key") for every other failure.
    """
    global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
    try:
        if isinstance(api_key, str):
            api_key = _get_bearer_token(api_key=api_key)
        ### USER-DEFINED AUTH FUNCTION ###
        if user_custom_auth is not None:
            response = await user_custom_auth(request=request, api_key=api_key)
            return UserAPIKeyAuth.model_validate(response)
        ### LITELLM-DEFINED AUTH FUNCTION ###
        if master_key is None:
            if isinstance(api_key, str):
                return UserAPIKeyAuth(api_key=api_key)
            else:
                return UserAPIKeyAuth()

        route: str = request.url.path
        if route == "/user/auth":
            if general_settings.get("allow_user_auth", False) == True:
                return UserAPIKeyAuth()
            else:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="'allow_user_auth' not set or set to False",
                )

        if api_key is None:  # only require api key if master key is set
            raise Exception(f"No api key passed in.")

        # note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
        is_master_key_valid = secrets.compare_digest(api_key, master_key)
        if is_master_key_valid:
            return UserAPIKeyAuth(api_key=master_key)

        if route.startswith("/config/") and not is_master_key_valid:
            raise Exception(f"Only admin can modify config")

        # NOTE(review): `and` binds tighter than `or`, so the original
        # condition only applied the `allow_user_auth` exemption to "/model/"
        # routes; "/key/" and "/user/" routes are always rejected here. The
        # grouping below makes that existing behaviour explicit - confirm
        # whether the exemption was meant to cover all three prefixes.
        if (
            route.startswith("/key/")
            or route.startswith("/user/")
            or (
                route.startswith("/model/")
                and not is_master_key_valid
                and general_settings.get("allow_user_auth", False) != True
            )
        ):
            raise Exception(
                f"If master key is set, only master key can be used to generate, delete, update or get info for new keys/users"
            )

        if (
            prisma_client is None and custom_db_client is None
        ):  # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
            raise Exception("No connected db.")

        ## check for cache hit (In-Memory Cache)
        valid_token = user_api_key_cache.get_cache(key=api_key)
        verbose_proxy_logger.debug(f"valid_token from cache: {valid_token}")
        if valid_token is None:
            ## check db
            verbose_proxy_logger.debug(f"api key: {api_key}")
            if prisma_client is not None:
                valid_token = await prisma_client.get_data(
                    token=api_key,
                )
            elif custom_db_client is not None:
                valid_token = await custom_db_client.get_data(
                    key=api_key, table_name="key"
                )
            if valid_token is None:
                # Unknown key: fail explicitly instead of via AttributeError.
                raise Exception(f"Invalid token")
            # Token exists, now check expiration.
            expires = valid_token.expires
            if expires is not None:
                # The stored expiry may already be a datetime or an ISO string.
                expiry_time = (
                    expires
                    if isinstance(expires, datetime)
                    else datetime.fromisoformat(expires)
                )
                # Compare like with like: naive values are treated as UTC,
                # aware values are compared against an aware "now".
                now = (
                    datetime.utcnow()
                    if expiry_time.tzinfo is None
                    else datetime.now(timezone.utc)
                )
                if expiry_time < now:
                    # Token exists but is expired.
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="expired user key",
                    )
                # Token exists and is not expired: continue with the normal
                # checks below (previously this returned an unbound name and
                # therefore rejected every unexpired key with an expiry).
            verbose_proxy_logger.debug(f"valid token from prisma: {valid_token}")
            user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
        else:
            verbose_proxy_logger.debug(f"API Key Cache Hit!")

        if not valid_token:
            raise Exception(f"Invalid token")

        litellm.model_alias_map = valid_token.aliases
        config = valid_token.config
        if config != {}:
            model_list = config.get("model_list", [])
            llm_model_list = model_list
            verbose_proxy_logger.debug(
                f"\n new llm router model list {llm_model_list}"
            )
        # assume an empty model list means all models are allowed to be called
        if len(valid_token.models) > 0:
            try:
                data = await request.json()
            except json.JSONDecodeError:
                data = {}  # Provide a default value, such as an empty dictionary
            model = data.get("model", None)
            if model in litellm.model_alias_map:
                model = litellm.model_alias_map[model]
            if model and model not in valid_token.models:
                raise Exception(f"Token not allowed to access model")
        api_key = valid_token.token
        valid_token_dict = _get_pydantic_json_dict(valid_token)
        valid_token_dict.pop("token", None)
        # asyncio create task to update the user api key cache with the user db table as well
        # This makes the user row data accessible to pre-api call hooks.
        user_row_db = prisma_client if prisma_client is not None else custom_db_client
        if user_row_db is not None:
            asyncio.create_task(
                _cache_user_row(
                    user_id=valid_token.user_id,
                    cache=user_api_key_cache,
                    db=user_row_db,
                )
            )
        return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
    except Exception as e:
        # verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}")
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise e
        else:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="invalid user key",
            )
def prisma_setup(database_url: Optional[str]):
    """Create the global ``prisma_client`` when a database URL is provided.

    Args:
        database_url: Connection URL; when None nothing is done.

    Raises:
        Exception: Whatever ``PrismaClient`` raises during construction is
            propagated unchanged.
    """
    global prisma_client, proxy_logging_obj, user_api_key_cache
    if database_url is None:
        return
    prisma_client = PrismaClient(
        database_url=database_url, proxy_logging_obj=proxy_logging_obj
    )
def load_from_azure_key_vault(use_azure_key_vault: bool = False):
    """Configure litellm to read secrets from Azure Key Vault.

    Reads ``AZURE_KEY_VAULT_URI``, ``AZURE_CLIENT_ID``, ``AZURE_CLIENT_SECRET``
    and ``AZURE_TENANT_ID`` from the environment, builds a ``SecretClient`` and
    stores it on ``litellm.secret_manager_client``.

    This is best-effort: any failure (missing Azure packages, missing
    environment variables, client construction errors) is logged at debug
    level and not raised.

    Args:
        use_azure_key_vault: When exactly ``False`` the function returns
            immediately without doing anything.
    """
    if use_azure_key_vault is False:
        return

    try:
        from azure.keyvault.secrets import SecretClient
        from azure.identity import ClientSecretCredential

        # Set your Azure Key Vault URI
        KVUri = os.getenv("AZURE_KEY_VAULT_URI", None)

        # Set your Azure AD application/client ID, client secret, and tenant ID
        client_id = os.getenv("AZURE_CLIENT_ID", None)
        client_secret = os.getenv("AZURE_CLIENT_SECRET", None)
        tenant_id = os.getenv("AZURE_TENANT_ID", None)

        if (
            KVUri is not None
            and client_id is not None
            and client_secret is not None
            and tenant_id is not None
        ):
            # Initialize the ClientSecretCredential
            credential = ClientSecretCredential(
                client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
            )

            # Create the SecretClient using the credential
            client = SecretClient(vault_url=KVUri, credential=credential)

            litellm.secret_manager_client = client
            litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT
        else:
            raise Exception(
                f"Missing KVUri or client_id or client_secret or tenant_id from environment"
            )
    except Exception as e:
        # Include the actual error: the failure is not always a missing
        # package (e.g. the missing-environment-variable case above).
        verbose_proxy_logger.debug(
            "Error when loading keys from Azure Key Vault. Ensure you run `pip install azure-identity azure-keyvault-secrets`"
            f" - {str(e)}"
        )
def cost_tracking():
    """Register ``track_cost_callback`` as a litellm success callback.

    Only acts when a database client is configured and
    ``litellm.success_callback`` is a list; the callback is added at most once.
    """
    global prisma_client, custom_db_client
    if prisma_client is None and custom_db_client is None:
        return
    if not isinstance(litellm.success_callback, list):
        return
    verbose_proxy_logger.debug("setting litellm success callback to track cost")
    already_registered = track_cost_callback in litellm.success_callback  # type: ignore
    if not already_registered:
        litellm.success_callback.append(track_cost_callback)  # type: ignore
async def track_cost_callback(
    kwargs,  # kwargs to completion
    completion_response: litellm.ModelResponse,  # response from completion
    start_time=None,
    end_time=None,  # start/end time for completion
):
    """Success callback that records the cost of a completion in the database.

    The cost is computed with ``litellm.completion_cost`` either from
    ``kwargs["complete_streaming_response"]`` (when present) or, when
    ``kwargs["stream"] == False``, from ``completion_response``. Spend is then
    written for the key and user found in
    ``kwargs["litellm_params"]["metadata"]``.

    Errors are logged at debug level and never raised.

    Args:
        kwargs: Keyword arguments of the completion call.
        completion_response: Response returned by the completion call.
        start_time: Unused.
        end_time: Unused.
    """
    global prisma_client, custom_db_client
    try:
        # check if it has collected an entire stream response
        verbose_proxy_logger.debug(
            f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
        )
        if "complete_streaming_response" in kwargs:
            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
            completion_response = kwargs["complete_streaming_response"]
            response_cost = litellm.completion_cost(
                completion_response=completion_response
            )
            verbose_proxy_logger.debug(f"streaming response_cost {response_cost}")
        elif kwargs["stream"] == False:  # for non streaming responses
            response_cost = litellm.completion_cost(
                completion_response=completion_response
            )
        else:
            # Neither a completed stream nor a non-streaming call: nothing to record.
            return

        metadata = kwargs["litellm_params"]["metadata"]
        user_api_key = metadata.get("user_api_key", None)
        user_id = metadata.get("user_api_key_user_id", None)
        if user_api_key and (
            prisma_client is not None or custom_db_client is not None
        ):
            # user_id is passed for streaming and non-streaming calls alike so
            # per-user spend is updated in both cases.
            await update_database(
                token=user_api_key, response_cost=response_cost, user_id=user_id
            )
    except Exception as e:
        verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
async def update_database(token, response_cost, user_id=None):
    """Add ``response_cost`` to the stored spend of a key and, optionally, a user.

    The user update and key update run concurrently. Each one reads the
    current spend, adds ``response_cost`` and writes the sum back through
    whichever database client is configured (``prisma_client`` is preferred
    over ``custom_db_client``). Any error is logged at debug level and
    swallowed.

    Args:
        token: Key whose spend is updated.
        response_cost: Cost to add.
        user_id: User whose spend is updated; skipped when None.
    """
    try:
        verbose_proxy_logger.debug(
            f"Enters prisma db call, token: {token}; user_id: {user_id}"
        )

        ### UPDATE USER SPEND ###
        async def _update_user_db():
            if user_id is None:
                return
            if prisma_client is not None:
                user_row = await prisma_client.get_data(user_id=user_id)
            elif custom_db_client is not None:
                user_row = await custom_db_client.get_data(
                    key=user_id, table_name="user"
                )
            previous_spend = 0 if user_row is None else user_row.spend

            # Calculate the new cost by adding the existing cost and response_cost
            updated_spend = previous_spend + response_cost

            verbose_proxy_logger.debug(f"new cost: {updated_spend}")
            # Update the cost column for the given user id
            spend_update = {"spend": updated_spend}
            if prisma_client is not None:
                await prisma_client.update_data(user_id=user_id, data=spend_update)
            elif custom_db_client is not None:
                await custom_db_client.update_data(
                    key=user_id, value=spend_update, table_name="user"
                )

        ### UPDATE KEY SPEND ###
        async def _update_key_db():
            if prisma_client is not None:
                # Fetch the existing cost for the given token
                key_row = await prisma_client.get_data(token=token)
                verbose_proxy_logger.debug(f"existing spend: {key_row}")
                previous_spend = 0 if key_row is None else key_row.spend
                # Calculate the new cost by adding the existing cost and response_cost
                updated_spend = previous_spend + response_cost

                verbose_proxy_logger.debug(f"new cost: {updated_spend}")
                # Update the cost column for the given token
                await prisma_client.update_data(
                    token=token, data={"spend": updated_spend}
                )
            elif custom_db_client is not None:
                # Fetch the existing cost for the given token
                key_row = await custom_db_client.get_data(key=token, table_name="key")
                verbose_proxy_logger.debug(f"existing spend: {key_row}")
                previous_spend = 0 if key_row is None else key_row.spend
                # Calculate the new cost by adding the existing cost and response_cost
                updated_spend = previous_spend + response_cost

                verbose_proxy_logger.debug(f"new cost: {updated_spend}")
                # Update the cost column for the given token
                await custom_db_client.update_data(
                    key=token, value={"spend": updated_spend}, table_name="key"
                )

        await asyncio.gather(_update_user_db(), _update_key_db())
    except Exception:
        verbose_proxy_logger.debug(
            f"Error updating Prisma database: {traceback.format_exc()}"
        )
def run_ollama_serve():
    """Start ``ollama serve`` as a child process with its output discarded.

    The process is started and not waited on. Failure to launch is logged at
    debug level and not raised.
    """
    try:
        with open(os.devnull, "w") as sink:
            # stdout and stderr of the child both go to the null device
            subprocess.Popen(["ollama", "serve"], stdout=sink, stderr=sink)
    except Exception as e:
        verbose_proxy_logger.debug(
            f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
"""
        )
async def _run_background_health_check():
    """
    Loop forever: health-check the endpoints in ``llm_model_list``, store the
    outcome in ``health_check_results``, then sleep ``health_check_interval``
    seconds before the next round.
    """
    global health_check_results, llm_model_list, health_check_interval
    while True:
        healthy, unhealthy = await perform_health_check(model_list=llm_model_list)

        # Publish the latest results on the shared module-level dict
        health_check_results.update(
            {
                "healthy_endpoints": healthy,
                "unhealthy_endpoints": unhealthy,
                "healthy_count": len(healthy),
                "unhealthy_count": len(unhealthy),
            }
        )

        await asyncio.sleep(health_check_interval)
class ProxyConfig:
    """
    Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
    """

    def __init__(self) -> None:
        pass

    def is_yaml(self, config_file_path: str) -> bool:
        """Return True if the path is an existing file ending in .yaml or .yml."""
        if not os.path.isfile(config_file_path):
            return False

        _, file_extension = os.path.splitext(config_file_path)
        return file_extension.lower() in (".yaml", ".yml")

    async def get_config(self, config_file_path: Optional[str] = None) -> dict:
        """Load the proxy config from YAML.

        Args:
            config_file_path: File to read. When given it also becomes the new
                global ``user_config_file_path``; otherwise the current global
                path is used.

        Returns:
            The parsed config, or an empty config skeleton when the file does
            not exist or is empty.
        """
        global prisma_client, user_config_file_path

        file_path = config_file_path or user_config_file_path
        if config_file_path is not None:
            user_config_file_path = config_file_path

        empty_config = {
            "model_list": [],
            "general_settings": {},
            "router_settings": {},
            "litellm_settings": {},
        }
        # Load existing config
        ## Yaml
        if os.path.exists(f"{file_path}"):
            with open(f"{file_path}", "r") as config_file:
                config = yaml.safe_load(config_file)
            if config is None:
                # An empty YAML document parses to None; fall back to the
                # skeleton so callers can always treat the result as a dict.
                config = empty_config
        else:
            config = empty_config

        ## DB
        if (
            prisma_client is not None
            and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True
        ):
            prisma_setup(database_url=None)  # in case it's not been connected yet
            _tasks = []
            keys = [
                "model_list",
                "general_settings",
                "router_settings",
                "litellm_settings",
            ]
            for k in keys:
                response = prisma_client.get_generic_data(
                    key="param_name", value=k, table_name="config"
                )
                _tasks.append(response)

            # NOTE(review): the rows fetched here are not merged into
            # `config`; only the YAML/skeleton value is returned.
            responses = await asyncio.gather(*_tasks)

        return config

    async def save_config(self, new_config: dict):
        """Persist ``new_config`` to YAML (and optionally the DB) and reload it.

        The config is written to ``user_config_file_path`` and then loaded via
        ``load_config`` to validate it. If loading fails, the previous config
        is written back.

        Args:
            new_config: The full config to save.

        Raises:
            HTTPException: 400 when ``new_config`` cannot be loaded.
        """
        global prisma_client, llm_router, user_config_file_path, llm_model_list, general_settings
        # Load existing config
        backup_config = await self.get_config()

        # Save the updated config
        ## YAML
        with open(f"{user_config_file_path}", "w") as config_file:
            yaml.dump(new_config, config_file, default_flow_style=False)

        # update Router - verifies if this is a valid config
        try:
            # Use this instance rather than the module-level singleton so the
            # method does not depend on a global defined after the class.
            (
                llm_router,
                llm_model_list,
                general_settings,
            ) = await self.load_config(
                router=llm_router, config_file_path=user_config_file_path
            )
        except Exception as e:
            traceback.print_exc()
            # Revert to old config instead
            with open(f"{user_config_file_path}", "w") as config_file:
                yaml.dump(backup_config, config_file, default_flow_style=False)
            raise HTTPException(status_code=400, detail="Invalid config passed in")

        ## DB - writes valid config to db
        # - Do not write restricted params like 'api_key' to the database
        # - if api_key is passed, save that to the local environment or connected secret manage (maybe expose `litellm.save_secret()`)
        if (
            prisma_client is not None
            and litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False) == True
        ):
            ### KEY REMOVAL ###
            models = new_config.get("model_list", [])
            for m in models:
                if m.get("litellm_params", {}).get("api_key", None) is not None:
                    # pop the key
                    api_key = m["litellm_params"].pop("api_key")
                    # store in local env
                    key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
                    os.environ[key_name] = api_key
                    # save the key name (not the value)
                    m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
            await prisma_client.insert_data(data=new_config, table_name="config")

    async def load_config(
        self, router: Optional[litellm.Router], config_file_path: str
    ):
        """
        Load config values into proxy global state

        Applies, in order: ``environment_variables``, ``litellm_settings``,
        ``general_settings``, ``model_list`` and ``router_settings``.

        Args:
            router: Existing router; it is replaced by a newly built one.
            config_file_path: Config file to load.

        Returns:
            Tuple ``(router, model_list, general_settings)``.

        Raises:
            ValueError: If ``key_management_system`` names an unknown system.
        """
        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client

        # Load existing config
        config = await self.get_config(config_file_path=config_file_path)
        ## PRINT YAML FOR CONFIRMING IT WORKS
        printed_yaml = copy.deepcopy(config)
        printed_yaml.pop("environment_variables", None)

        # NOTE(review): only `environment_variables` is removed above; any
        # api_key values inside model_list still reach this debug log.
        verbose_proxy_logger.debug(
            f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
        )

        ## ENVIRONMENT VARIABLES
        environment_variables = config.get("environment_variables", None)
        if environment_variables:
            for key, value in environment_variables.items():
                os.environ[key] = value

        ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
        litellm_settings = config.get("litellm_settings", None)
        if litellm_settings is None:
            litellm_settings = {}
        if litellm_settings:
            # ANSI escape code for blue text
            blue_color_code = "\033[94m"
            reset_color_code = "\033[0m"
            for key, value in litellm_settings.items():
                if key == "cache":
                    print(f"{blue_color_code}\nSetting Cache on Proxy")  # noqa
                    from litellm.caching import Cache

                    cache_params = {}
                    if "cache_params" in litellm_settings:
                        cache_params_in_config = litellm_settings["cache_params"]
                        # overwrite cache_params with cache_params_in_config
                        cache_params.update(cache_params_in_config)

                    cache_type = cache_params.get("type", "redis")

                    verbose_proxy_logger.debug(f"passed cache type={cache_type}")

                    if cache_type == "redis":
                        cache_host = litellm.get_secret("REDIS_HOST", None)
                        cache_port = litellm.get_secret("REDIS_PORT", None)
                        cache_password = litellm.get_secret("REDIS_PASSWORD", None)

                        cache_params.update(
                            {
                                "type": cache_type,
                                "host": cache_host,
                                "port": cache_port,
                                "password": cache_password,
                            }
                        )
                        # Assuming cache_type, cache_host, cache_port, and cache_password are strings
                        print(  # noqa
                            f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
                        )  # noqa
                        print(  # noqa
                            f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}"
                        )  # noqa
                        print(  # noqa
                            f"{blue_color_code}Cache Port:{reset_color_code} {cache_port}"
                        )  # noqa
                        # NOTE(review): this prints the Redis password to
                        # stdout in clear text - consider masking it.
                        print(  # noqa
                            f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}"
                        )
                        print()  # noqa

                    # users can pass os.environ/ variables on the proxy - we should read them from the env
                    # (distinct loop names so the outer `key`/`value` are not shadowed)
                    for param_name, param_value in cache_params.items():
                        if type(param_value) is str and param_value.startswith(
                            "os.environ/"
                        ):
                            cache_params[param_name] = litellm.get_secret(param_value)

                    ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
                    litellm.cache = Cache(**cache_params)
                    print(  # noqa
                        f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
                    )
                elif key == "callbacks":
                    litellm.callbacks = [
                        get_instance_fn(value=value, config_file_path=config_file_path)
                    ]
                    verbose_proxy_logger.debug(
                        f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
                    )
                elif key == "post_call_rules":
                    litellm.post_call_rules = [
                        get_instance_fn(value=value, config_file_path=config_file_path)
                    ]
                    verbose_proxy_logger.debug(
                        f"litellm.post_call_rules: {litellm.post_call_rules}"
                    )
                elif key == "success_callback":
                    litellm.success_callback = []

                    # initialize success callbacks
                    for callback in value:
                        # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
                        if "." in callback:
                            litellm.success_callback.append(
                                get_instance_fn(value=callback)
                            )
                        # these are litellm callbacks - "langfuse", "sentry", "wandb"
                        else:
                            litellm.success_callback.append(callback)
                    verbose_proxy_logger.debug(
                        f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}"
                    )
                elif key == "failure_callback":
                    litellm.failure_callback = []

                    # initialize failure callbacks
                    for callback in value:
                        # user passed custom_callbacks.async_on_succes_logger. They need us to import a function
                        if "." in callback:
                            litellm.failure_callback.append(
                                get_instance_fn(value=callback)
                            )
                        # these are litellm callbacks - "langfuse", "sentry", "wandb"
                        else:
                            litellm.failure_callback.append(callback)
                    # Log message fixed: it previously said "Success Callbacks".
                    verbose_proxy_logger.debug(
                        f"{blue_color_code} Initialized Failure Callbacks - {litellm.failure_callback} {reset_color_code}"
                    )
                elif key == "cache_params":
                    # this is set in the cache branch
                    # see usage here: https://docs.litellm.ai/docs/proxy/caching
                    pass
                else:
                    setattr(litellm, key, value)

        ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
        general_settings = config.get("general_settings", {})
        if general_settings is None:
            general_settings = {}
        if general_settings:
            ### LOAD SECRET MANAGER ###
            key_management_system = general_settings.get("key_management_system", None)
            if key_management_system is not None:
                if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
                    ### LOAD FROM AZURE KEY VAULT ###
                    load_from_azure_key_vault(use_azure_key_vault=True)
                elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
                    ### LOAD FROM GOOGLE KMS ###
                    load_google_kms(use_google_kms=True)
                else:
                    raise ValueError("Invalid Key Management System selected")
            ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
            use_google_kms = general_settings.get("use_google_kms", False)
            load_google_kms(use_google_kms=use_google_kms)
            ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
            use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
            load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
            ### ALERTING ###
            proxy_logging_obj.update_values(
                alerting=general_settings.get("alerting", None),
                alerting_threshold=general_settings.get("alerting_threshold", 600),
            )
            ### CONNECT TO DATABASE ###
            # The URL is only resolved here; no connection is opened in this method.
            database_url = general_settings.get("database_url", None)
            if database_url and database_url.startswith("os.environ/"):
                verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!")
                database_url = litellm.get_secret(database_url)
                verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}")
            ### MASTER KEY ###
            master_key = general_settings.get(
                "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
            )
            if master_key and master_key.startswith("os.environ/"):
                master_key = litellm.get_secret(master_key)
            ### CUSTOM API KEY AUTH ###
            ## pass filepath
            custom_auth = general_settings.get("custom_auth", None)
            if custom_auth is not None:
                user_custom_auth = get_instance_fn(
                    value=custom_auth, config_file_path=config_file_path
                )
            ## dynamodb
            database_type = general_settings.get("database_type", None)
            if database_type is not None and (
                database_type == "dynamo_db" or database_type == "dynamodb"
            ):
                database_args = general_settings.get("database_args", None)
                custom_db_client = DBClient(
                    custom_db_args=database_args, custom_db_type=database_type
                )
            ## COST TRACKING ##
            cost_tracking()
            ### BACKGROUND HEALTH CHECKS ###
            # Enable background health checks
            use_background_health_checks = general_settings.get(
                "background_health_checks", False
            )
            health_check_interval = general_settings.get("health_check_interval", 300)

        router_params: dict = {
            "num_retries": 3,
            "cache_responses": litellm.cache
            is not None,  # cache if user passed in cache values
        }
        ## MODEL LIST
        model_list = config.get("model_list", None)
        if model_list:
            router_params["model_list"] = model_list
            print(  # noqa
                f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m"
            )  # noqa
            for model in model_list:
                ### LOAD FROM os.environ/ ###
                for k, v in model["litellm_params"].items():
                    if isinstance(v, str) and v.startswith("os.environ/"):
                        model["litellm_params"][k] = litellm.get_secret(v)
                print(f"\033[32m {model.get('model_name', '')}\033[0m")  # noqa
                litellm_model_name = model["litellm_params"]["model"]
                litellm_model_api_base = model["litellm_params"].get("api_base", None)
                if "ollama" in litellm_model_name and litellm_model_api_base is None:
                    run_ollama_serve()

        ## ROUTER SETTINGS (e.g. routing_strategy, ...)
        router_settings = config.get("router_settings", None)
        if router_settings and isinstance(router_settings, dict):
            arg_spec = inspect.getfullargspec(litellm.Router)
            # model list already set
            exclude_args = {
                "self",
                "model_list",
            }

            # Only forward settings that litellm.Router actually accepts
            available_args = [x for x in arg_spec.args if x not in exclude_args]

            for k, v in router_settings.items():
                if k in available_args:
                    router_params[k] = v

        router = litellm.Router(**router_params)  # type:ignore
        return router, model_list, general_settings
# Module-level ProxyConfig instance.
proxy_config = ProxyConfig()
async def generate_key_helper_fn(
    duration: Optional[str],
    models: list,
    aliases: dict,
    config: dict,
    spend: float,
    max_budget: Optional[float] = None,
    token: Optional[str] = None,
    user_id: Optional[str] = None,
    user_email: Optional[str] = None,
    max_parallel_requests: Optional[int] = None,
    metadata: Optional[dict] = {},
):
    """
    Create a key (and its user data) in whichever database client is connected.

    Parameters:
    - duration: "<int><unit>" with unit one of s/m/h/d (e.g. "30m"); None means the key never expires.
    - models, aliases, config, spend, max_parallel_requests, metadata: stored on the key row
      (aliases, config and metadata are stored as JSON strings).
    - max_budget, user_email, user_id: stored as the user data. A random uuid4 is used when
      user_id is not given.
    - token: key value to store; a random "sk-..." value is generated when None.

    Returns:
    - dict with "token", "expires", "user_id" and "max_budget".

    Raises:
    - Exception: neither prisma_client nor custom_db_client is set.
    - ValueError: duration is malformed or uses an unsupported unit.
    - HTTPException (500): the database insert failed.
    """
    global prisma_client, custom_db_client

    if prisma_client is None and custom_db_client is None:
        raise Exception(
            f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys "
        )

    if token is None:
        token = f"sk-{secrets.token_urlsafe(16)}"

    seconds_per_unit = {"s": 1, "m": 60, "h": 3600, "d": 86400}

    def _duration_in_seconds(duration: str):
        # fullmatch (not match): trailing characters must be rejected, otherwise
        # "30ms" would silently be read as 30 minutes.
        match = re.fullmatch(r"(\d+)([smhd]?)", duration)
        if not match:
            raise ValueError("Invalid duration format")

        value, unit = match.groups()
        if unit not in seconds_per_unit:
            # a bare number without a unit ends up here
            raise ValueError("Unsupported duration unit")
        return int(value) * seconds_per_unit[unit]

    if duration is None:  # allow tokens that never expire
        expires = None
    else:
        duration_s = _duration_in_seconds(duration=duration)
        expires = datetime.utcnow() + timedelta(seconds=duration_s)

    aliases_json = json.dumps(aliases)
    config_json = json.dumps(config)
    # `metadata` is only read here, never mutated, so the shared default dict is safe.
    metadata_json = json.dumps(metadata)
    user_id = user_id or str(uuid.uuid4())
    try:
        user_data = {
            "max_budget": max_budget,
            "user_email": user_email,
            "user_id": user_id,
            "spend": spend,
        }
        key_data = {
            "token": token,
            "expires": expires,
            "models": models,
            "aliases": aliases_json,
            "config": config_json,
            "spend": spend,
            "user_id": user_id,
            "max_parallel_requests": max_parallel_requests,
            "metadata": metadata_json,
        }
        if prisma_client is not None:
            # prisma receives key and user fields merged into a single record
            verification_token_data = dict(key_data)
            verification_token_data.update(user_data)
            verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
            await prisma_client.insert_data(data=verification_token_data)
        elif custom_db_client is not None:
            ## CREATE USER (If necessary)
            await custom_db_client.insert_data(value=user_data, table_name="user")
            ## CREATE KEY
            await custom_db_client.insert_data(value=key_data, table_name="key")
    except Exception as e:
        traceback.print_exc()
        # chain the cause so the original failure stays visible in logs
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from e
    return {
        "token": token,
        "expires": expires,
        "user_id": user_id,
        "max_budget": max_budget,
    }
async def delete_verification_token(tokens: List):
    """
    Delete the given tokens through the prisma client.

    Returns whatever `prisma_client.delete_data` returns. Any failure,
    including no prisma client being set, is printed and reported as an
    HTTP 500.
    """
    global prisma_client
    try:
        if not prisma_client:
            # no database connected -> handled by the generic failure path below
            raise Exception
        removed = await prisma_client.delete_data(tokens=tokens)
    except Exception:
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
    return removed
def save_worker_config(**data):
    """Store the given keyword arguments as a JSON string in the WORKER_CONFIG env var."""
    import json

    serialized = json.dumps(data)
    os.environ["WORKER_CONFIG"] = serialized
def _set_proxy_log_level(level: int, verbose: bool = False):
    """Set `level` on the router and proxy loggers; optionally turn on litellm.set_verbose."""
    from litellm._logging import verbose_router_logger, verbose_proxy_logger

    verbose_router_logger.setLevel(level=level)
    verbose_proxy_logger.setLevel(level=level)
    if verbose:
        litellm.set_verbose = True


async def initialize(
    model=None,
    alias=None,
    api_base=None,
    api_version=None,
    debug=False,
    detailed_debug=False,
    temperature=None,
    max_tokens=None,
    request_timeout=600,
    max_budget=None,
    telemetry=False,
    drop_params=True,
    add_function_to_prompt=True,
    headers=None,
    save=False,
    use_queue=False,
    config=None,
):
    """
    Apply CLI / worker settings to the module-level proxy state.

    - debug / detailed_debug select INFO / DEBUG logging. When both are False the
      LITELLM_LOG environment variable ("INFO" or "DEBUG", case-insensitive) is used.
    - config: when set, the router, model list and general settings are loaded through
      proxy_config.load_config.
    - model, headers, api_base, max_tokens, temperature, request_timeout, telemetry are
      copied into the matching `user_*` globals (only when truthy, except model/debug/
      telemetry which are always assigned).
    - api_version is exported as the AZURE_API_VERSION environment variable.
    - drop_params, add_function_to_prompt, max_budget are set on the litellm module.
    - alias, save and use_queue are accepted but not used by this function.
    """
    # NOTE: the original declaration listed `user_user_max_tokens` (typo), which made
    # `user_max_tokens` a local and the --max_tokens setting never reached the endpoints.
    global user_model, user_api_base, user_debug, user_detailed_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth, prisma_client
    import logging

    user_model = model
    user_debug = debug
    if debug == True:  # this needs to be first, so users can see Router init debug logs
        # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
        _set_proxy_log_level(logging.INFO)
    if detailed_debug == True:
        _set_proxy_log_level(logging.DEBUG, verbose=True)
    elif debug == False and detailed_debug == False:
        # users can control proxy debugging using env variable = 'LITELLM_LOG'
        litellm_log_setting = os.environ.get("LITELLM_LOG", "").upper()
        if litellm_log_setting == "INFO":
            # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
            _set_proxy_log_level(logging.INFO)
        elif litellm_log_setting == "DEBUG":
            _set_proxy_log_level(logging.DEBUG, verbose=True)

    if config:
        (
            llm_router,
            llm_model_list,
            general_settings,
        ) = await proxy_config.load_config(router=llm_router, config_file_path=config)

    ## model-specific params
    if headers:
        user_headers = headers
    if api_base:
        user_api_base = api_base
    if api_version:
        # set this for azure - litellm can read this from the env
        os.environ["AZURE_API_VERSION"] = api_version
    if max_tokens:
        user_max_tokens = max_tokens
    if temperature:
        user_temperature = temperature
    if request_timeout:
        user_request_timeout = request_timeout

    ## litellm-specific params
    if drop_params == True:
        litellm.drop_params = True
    if add_function_to_prompt == True:
        litellm.add_function_to_prompt = True
    if max_budget:
        litellm.max_budget = max_budget

    user_telemetry = telemetry
    usage_telemetry(feature="local_proxy_server")
# for streaming | |
def data_generator(response):
    """
    Yield each chunk of `response` as a server-sent-event line ("data: <json>\\n\\n").

    Chunks exposing `.dict()` are serialized through it; anything else is passed to
    json.dumps directly.
    """
    verbose_proxy_logger.debug("inside generator")
    for chunk in response:
        verbose_proxy_logger.debug(f"returned chunk: {chunk}")
        try:
            yield f"data: {json.dumps(chunk.dict())}\n\n"
        except Exception:
            # `except Exception` (not a bare except): the yield above sits inside the
            # try, so a bare except would also catch GeneratorExit when the consumer
            # closes the stream early, and the generator would then fail to close.
            yield f"data: {json.dumps(chunk)}\n\n"
async def async_data_generator(response, user_api_key_dict):
    """
    Stream an async `response` as server-sent-event lines.

    Each chunk is emitted as "data: <json of chunk.dict()>"; if that fails the error
    text is emitted in its place. After the last chunk a slow-response check is
    scheduled on proxy_logging_obj and a final "data: [DONE]" line is sent. Any other
    exception during streaming is emitted as a "data: <error>" line as well.
    `user_api_key_dict` is accepted but not read here.
    """
    verbose_proxy_logger.debug("inside generator")
    try:
        stream_started = time.time()
        async for part in response:
            verbose_proxy_logger.debug(f"returned chunk: {part}")
            try:
                payload = json.dumps(part.dict())
                yield f"data: {payload}\n\n"
            except Exception as chunk_err:
                yield f"data: {str(chunk_err)}\n\n"

        ### ALERTING ###
        stream_finished = time.time()
        slow_response_check = proxy_logging_obj.response_taking_too_long(
            start_time=stream_started, end_time=stream_finished, type="slow_response"
        )
        asyncio.create_task(slow_response_check)

        # Streaming is done, yield the [DONE] chunk
        yield "data: [DONE]\n\n"
    except Exception as stream_err:
        yield f"data: {str(stream_err)}\n\n"
def get_litellm_model_info(model: dict = {}):
    """
    Look up litellm's info for a model-list entry.

    The lookup key is `model["litellm_params"]["model"]`; when that name contains
    "azure", `model["model_info"]["base_model"]` is used instead. Returns {} when no
    model name is present or when the lookup fails for any reason.
    """
    # `model` is only read, never mutated, so the shared default dict is safe.
    model_info = model.get("model_info", {})
    model_to_lookup = model.get("litellm_params", {}).get("model", None)
    if model_to_lookup is None:
        return {}
    try:
        if "azure" in model_to_lookup:
            model_to_lookup = model_info.get("base_model", None)
        return litellm.get_model_info(model_to_lookup)
    except Exception:
        # this should not block returning on /model/info
        # if litellm does not have info on the model it should return {}
        return {}
def parse_cache_control(cache_control):
    """
    Parse a Cache-Control header value into a dict.

    "name=value" directives map name -> value (as a string); bare directives map
    name -> True. Directives are comma separated; surrounding whitespace is ignored
    and empty segments are skipped. Only the first "=" splits name from value, so
    values may themselves contain "=".
    """
    cache_dict = {}
    for raw_directive in cache_control.split(","):
        directive = raw_directive.strip()
        if not directive:
            continue
        name, separator, value = directive.partition("=")
        cache_dict[name.strip()] = value.strip() if separator else True
    return cache_dict
async def startup_event():
    """
    Server startup: load the master key, set up the database client, load the worker
    config, initialize callbacks / background health checks, connect the database
    client(s) and register the master key in the database.
    """
    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings
    import json

    ### LOAD MASTER KEY ###
    # check if master key set in environment - load from there
    master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
    # check if DATABASE_URL in environment - load from there
    if prisma_client is None:
        prisma_setup(database_url=os.getenv("DATABASE_URL"))

    ### LOAD CONFIG ###
    worker_config = litellm.get_secret("WORKER_CONFIG")
    verbose_proxy_logger.debug(f"worker_config: {worker_config}")
    # check if it's a valid file path
    if os.path.isfile(worker_config):
        if proxy_config.is_yaml(config_file_path=worker_config):
            (
                llm_router,
                llm_model_list,
                general_settings,
            ) = await proxy_config.load_config(
                router=llm_router, config_file_path=worker_config
            )
        else:
            # NOTE(review): worker_config is a file path here, so `**worker_config`
            # looks like it cannot succeed - confirm what non-YAML files should do.
            await initialize(**worker_config)
    else:
        # if not, assume it's a json string
        worker_config = json.loads(os.getenv("WORKER_CONFIG"))
        await initialize(**worker_config)

    # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging
    # errors on startup, not when calls are being made
    proxy_logging_obj._init_litellm_callbacks()

    if use_background_health_checks:
        asyncio.create_task(
            _run_background_health_check()
        )  # start the background health check coroutine.

    verbose_proxy_logger.debug(f"prisma client - {prisma_client}")
    if prisma_client is not None:
        await prisma_client.connect()

    if custom_db_client is not None:
        await custom_db_client.connect()

    # add master key to db - once, whichever client is connected. The helper picks the
    # client itself, so calling it once per client would only repeat the same insert.
    if master_key is not None and (
        prisma_client is not None or custom_db_client is not None
    ):
        await generate_key_helper_fn(
            duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
        )
#### API ENDPOINTS #### | |
# if project requires model list | |
def model_list():
    """
    Return the models known to the proxy in the OpenAI "list" response shape.

    Sources, in order: litellm's valid models (only when the general setting
    `infer_model_from_keys` is on), the `model_name` of every configured model-list
    entry, the CLI model (`user_model`), and the models reported by a local Ollama
    server. Names are de-duplicated while keeping first-seen order.
    """
    global llm_model_list, general_settings
    all_models = []
    if general_settings.get("infer_model_from_keys", False):
        all_models = list(litellm.utils.get_valid_models())
    if llm_model_list:
        all_models.extend(m["model_name"] for m in llm_model_list)
    if user_model is not None:
        all_models.append(user_model)
    verbose_proxy_logger.debug(f"all_models: {all_models}")

    ### CHECK OLLAMA MODELS ###
    try:
        # Best effort only: without a timeout an unresponsive local server would
        # stall this endpoint indefinitely.
        response = requests.get("http://0.0.0.0:11434/api/tags", timeout=5)
        models = response.json()["models"]
        all_models.extend(
            "ollama/" + m["name"].replace(":latest", "") for m in models
        )
    except Exception as e:
        verbose_proxy_logger.debug(f"Skipping ollama model discovery: {e}")

    # dict.fromkeys de-duplicates and, unlike set(), keeps a stable order
    unique_models = list(dict.fromkeys(all_models))
    return dict(
        data=[
            {
                "id": model,
                "object": "model",
                "created": 1677610602,
                "owned_by": "openai",
            }
            for model in unique_models
        ],
        object="list",
    )
async def completion(
    request: Request,
    fastapi_response: Response,
    model: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    background_tasks: BackgroundTasks = BackgroundTasks(),
):
    """
    Text-completion endpoint handler.

    Parses the request body, applies server / CLI overrides, runs the pre-call hook and
    sends the request either to the router or directly to litellm.atext_completion.
    Streaming requests return a StreamingResponse; otherwise the response object is
    returned and the `x-litellm-model-id` header is set. Any failure is raised as an
    HTTPException carrying the error text and traceback.
    """
    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
    try:
        body = await request.body()
        body_str = body.decode()
        # JSON first: parsing a JSON body as a Python literal decodes escapes
        # differently (e.g. "\ud83d\ude00" becomes two lone surrogates, "\/" keeps the
        # backslash). The literal fallback is kept for clients sending Python-style dicts.
        try:
            data = json.loads(body_str)
        except ValueError as json_error:
            try:
                data = ast.literal_eval(body_str)
            except Exception:
                raise json_error

        data["user"] = data.get("user", user_api_key_dict.user_id)
        data["model"] = (
            general_settings.get("completion_model", None)  # server default
            or user_model  # model name passed via cli args
            or model  # for azure deployments
            or data["model"]  # default passed in http request
        )
        if user_model:
            data["model"] = user_model

        metadata = data.setdefault("metadata", {})
        metadata["user_api_key"] = user_api_key_dict.api_key
        metadata["user_api_key_user_id"] = user_api_key_dict.user_id
        metadata["headers"] = dict(request.headers)

        # override with user settings, these are params passed via cli
        if user_temperature:
            data["temperature"] = user_temperature
        if user_request_timeout:
            data["request_timeout"] = user_request_timeout
        if user_max_tokens:
            data["max_tokens"] = user_max_tokens
        if user_api_base:
            data["api_base"] = user_api_base

        ### CALL HOOKS ### - modify incoming data before calling the model
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
        )

        start_time = time.time()

        ### ROUTE THE REQUESTs ###
        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        # skip router if user passed their key
        if "api_key" in data:
            response = await litellm.atext_completion(**data)
        elif (
            llm_router is not None and data["model"] in router_model_names
        ):  # model in router model list
            response = await llm_router.atext_completion(**data)
        elif (
            llm_router is not None
            and llm_router.model_group_alias is not None
            and data["model"] in llm_router.model_group_alias
        ):  # model set in model_group_alias
            response = await llm_router.atext_completion(**data)
        elif (
            llm_router is not None and data["model"] in llm_router.deployment_names
        ):  # model in router deployments, calling a specific deployment on the router
            response = await llm_router.atext_completion(
                **data, specific_deployment=True
            )
        else:  # router is not set
            response = await litellm.atext_completion(**data)

        if hasattr(response, "_hidden_params"):
            model_id = response._hidden_params.get("model_id", None) or ""
        else:
            model_id = ""

        verbose_proxy_logger.debug(f"final response: {response}")
        if (
            "stream" in data and data["stream"] == True
        ):  # use generate_responses to stream responses
            custom_headers = {"x-litellm-model-id": model_id}
            return StreamingResponse(
                async_data_generator(
                    user_api_key_dict=user_api_key_dict,
                    response=response,
                ),
                media_type="text/event-stream",
                headers=custom_headers,
            )

        ### ALERTING ###
        end_time = time.time()
        asyncio.create_task(
            proxy_logging_obj.response_taking_too_long(
                start_time=start_time, end_time=end_time, type="slow_response"
            )
        )

        fastapi_response.headers["x-litellm-model-id"] = model_id
        return response
    except Exception as e:
        verbose_proxy_logger.debug(f"EXCEPTION RAISED IN PROXY MAIN.PY")
        verbose_proxy_logger.debug(
            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
        )
        traceback.print_exc()
        error_traceback = traceback.format_exc()
        error_msg = f"{str(e)}\n\n{error_traceback}"
        # use the exception's own status code when it has one (no local named
        # `status`, which would shadow the fastapi import inside this function)
        error_status_code = getattr(e, "status_code", 500)
        raise HTTPException(status_code=error_status_code, detail=error_msg)
# azure compatible endpoint | |
async def chat_completion(
    request: Request,
    fastapi_response: Response,
    model: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    background_tasks: BackgroundTasks = BackgroundTasks(),
):
    """
    Chat-completion endpoint handler.

    Parses the request body, records the original request under
    `proxy_server_request`, reads `s-maxage` from the Cache-Control header into
    `ttl`, applies server / CLI overrides, runs the pre-call hook and sends the
    request to the router, a per-request router built from `user_config`, or
    litellm.acompletion. Streaming requests return a StreamingResponse; otherwise
    the response object is returned with the `x-litellm-model-id` header set.

    On failure the post-call failure hook runs; HTTPExceptions are re-raised as-is and
    everything else is raised as an HTTPException carrying the error text and traceback.
    """
    global general_settings, user_debug, proxy_logging_obj, llm_model_list
    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
    # bound before the try so the except handler can always inspect it
    data = {}
    try:
        body = await request.body()
        body_str = body.decode()
        # JSON first: parsing a JSON body as a Python literal decodes escapes
        # differently (e.g. "\ud83d\ude00" becomes two lone surrogates, "\/" keeps the
        # backslash). The literal fallback is kept for clients sending Python-style dicts.
        try:
            data = json.loads(body_str)
        except ValueError as json_error:
            try:
                data = ast.literal_eval(body_str)
            except Exception:
                raise json_error

        # Include original request and headers in the data
        data["proxy_server_request"] = {
            "url": str(request.url),
            "method": request.method,
            "headers": dict(request.headers),
            "body": copy.copy(data),  # use copy instead of deepcopy
        }

        ## Cache Controls
        headers = request.headers
        verbose_proxy_logger.debug(f"Request Headers: {headers}")
        cache_control_header = headers.get("Cache-Control", None)
        if cache_control_header:
            cache_dict = parse_cache_control(cache_control_header)
            data["ttl"] = cache_dict.get("s-maxage")

        verbose_proxy_logger.debug(f"receiving data: {data}")
        data["model"] = (
            general_settings.get("completion_model", None)  # server default
            or user_model  # model name passed via cli args
            or model  # for azure deployments
            or data["model"]  # default passed in http request
        )

        # users can pass in 'user' param to /chat/completions. Don't override it
        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
            # if users are using user_api_key_auth, set `user` in `data`
            data["user"] = user_api_key_dict.user_id

        if "metadata" in data:
            verbose_proxy_logger.debug(f'received metadata: {data["metadata"]}')
        metadata = data.setdefault("metadata", {})
        metadata["user_api_key"] = user_api_key_dict.api_key
        metadata["user_api_key_user_id"] = user_api_key_dict.user_id
        metadata["headers"] = dict(request.headers)

        # override with user settings, these are params passed via cli
        if user_temperature:
            data["temperature"] = user_temperature
        if user_request_timeout:
            data["request_timeout"] = user_request_timeout
        if user_max_tokens:
            data["max_tokens"] = user_max_tokens
        if user_api_base:
            data["api_base"] = user_api_base

        ### CALL HOOKS ### - modify incoming data before calling the model
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
        )

        start_time = time.time()

        ### ROUTE THE REQUEST ###
        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        # skip router if user passed their key
        if "api_key" in data:
            response = await litellm.acompletion(**data)
        elif "user_config" in data:
            # initialize a new router instance. make request using this Router
            router_config = data.pop("user_config")
            user_router = litellm.Router(**router_config)
            response = await user_router.acompletion(**data)
        elif (
            llm_router is not None and data["model"] in router_model_names
        ):  # model in router model list
            response = await llm_router.acompletion(**data)
        elif (
            llm_router is not None
            and llm_router.model_group_alias is not None
            and data["model"] in llm_router.model_group_alias
        ):  # model set in model_group_alias
            response = await llm_router.acompletion(**data)
        elif (
            llm_router is not None and data["model"] in llm_router.deployment_names
        ):  # model in router deployments, calling a specific deployment on the router
            response = await llm_router.acompletion(**data, specific_deployment=True)
        else:  # router is not set
            response = await litellm.acompletion(**data)

        if hasattr(response, "_hidden_params"):
            model_id = response._hidden_params.get("model_id", None) or ""
        else:
            model_id = ""

        if (
            "stream" in data and data["stream"] == True
        ):  # use generate_responses to stream responses
            custom_headers = {"x-litellm-model-id": model_id}
            return StreamingResponse(
                async_data_generator(
                    user_api_key_dict=user_api_key_dict,
                    response=response,
                ),
                media_type="text/event-stream",
                headers=custom_headers,
            )

        ### ALERTING ###
        end_time = time.time()
        asyncio.create_task(
            proxy_logging_obj.response_taking_too_long(
                start_time=start_time, end_time=end_time, type="slow_response"
            )
        )

        fastapi_response.headers["x-litellm-model-id"] = model_id
        return response
    except Exception as e:
        traceback.print_exc()
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e
        )
        verbose_proxy_logger.debug(
            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
        )
        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        # the body may have parsed to something other than a dict (e.g. a JSON list)
        requested_model = data.get("model", "") if isinstance(data, dict) else ""
        if llm_router is not None and requested_model in router_model_names:
            verbose_proxy_logger.debug("Results from router")
            verbose_proxy_logger.debug("\nRouter stats")
            verbose_proxy_logger.debug("\nTotal Calls made")
            for key, value in llm_router.total_calls.items():
                verbose_proxy_logger.debug(f"{key}: {value}")
            verbose_proxy_logger.debug("\nSuccess Calls made")
            for key, value in llm_router.success_calls.items():
                verbose_proxy_logger.debug(f"{key}: {value}")
            verbose_proxy_logger.debug("\nFail Calls made")
            for key, value in llm_router.fail_calls.items():
                verbose_proxy_logger.debug(f"{key}: {value}")
        if user_debug:
            traceback.print_exc()

        if isinstance(e, HTTPException):
            raise e
        error_traceback = traceback.format_exc()
        error_msg = f"{str(e)}\n\n{error_traceback}"
        # use the exception's own status code when it has one (no local named
        # `status`, which would shadow the fastapi import inside this function)
        error_status_code = getattr(e, "status_code", 500)
        raise HTTPException(status_code=error_status_code, detail=error_msg)
async def embeddings(
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    background_tasks: BackgroundTasks = BackgroundTasks(),
):
    """
    Embeddings endpoint handler.

    Parses the body with orjson, records the original request under
    `proxy_server_request`, applies server / CLI model overrides, converts
    token-array input back to text for deployments that are not OpenAI/Azure, runs the
    pre-call hook and sends the request to the router, a per-request router built from
    `user_config`, or litellm.aembedding.

    On failure the post-call failure hook runs; HTTPExceptions are re-raised as-is and
    everything else is raised as an HTTPException carrying the error text and traceback.
    """
    global proxy_logging_obj
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data["proxy_server_request"] = {
            "url": str(request.url),
            "method": request.method,
            "headers": dict(request.headers),
            "body": copy.copy(data),  # use copy instead of deepcopy
        }

        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
            data["user"] = user_api_key_dict.user_id

        data["model"] = (
            general_settings.get("embedding_model", None)  # server default
            or user_model  # model name passed via cli args
            or data["model"]  # default passed in http request
        )
        if user_model:
            data["model"] = user_model

        metadata = data.setdefault("metadata", {})
        metadata["user_api_key"] = user_api_key_dict.api_key
        metadata["headers"] = dict(request.headers)
        metadata["user_api_key_user_id"] = user_api_key_dict.user_id

        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )

        # check if array of tokens passed in; the length checks keep empty input from
        # raising IndexError here
        raw_input = data.get("input")
        is_token_array = (
            isinstance(raw_input, list)
            and len(raw_input) > 0
            and isinstance(raw_input[0], list)
            and len(raw_input[0]) > 0
            and isinstance(raw_input[0][0], int)
        )
        if (
            is_token_array
            and llm_model_list is not None
            and data["model"] in router_model_names
        ):
            # check if non-openai/azure model called - e.g. for langchain integration
            for m in llm_model_list:
                # Only deployments of the requested model decide this. Previously any
                # entry with a different model_name fell into the decode branch.
                if m["model_name"] != data["model"]:
                    continue
                deployment_model = m["litellm_params"]["model"]
                if (
                    deployment_model in litellm.open_ai_embedding_models
                    or deployment_model.startswith("azure/")
                ):
                    continue
                # non-openai/azure embedding model called with token input
                data["input"] = [
                    litellm.decode(model="gpt-3.5-turbo", tokens=i)
                    for i in data["input"]
                ]
                break

        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
        )

        start_time = time.time()

        ## ROUTE TO CORRECT ENDPOINT ##
        # skip router if user passed their key
        if "api_key" in data:
            response = await litellm.aembedding(**data)
        elif "user_config" in data:
            # initialize a new router instance. make request using this Router
            router_config = data.pop("user_config")
            user_router = litellm.Router(**router_config)
            response = await user_router.aembedding(**data)
        elif (
            llm_router is not None and data["model"] in router_model_names
        ):  # model in router model list
            response = await llm_router.aembedding(**data)
        elif (
            llm_router is not None
            and llm_router.model_group_alias is not None
            and data["model"] in llm_router.model_group_alias
        ):  # model set in model_group_alias
            response = await llm_router.aembedding(
                **data
            )  # ensure this goes the llm_router, router will do the correct alias mapping
        elif (
            llm_router is not None and data["model"] in llm_router.deployment_names
        ):  # model in router deployments, calling a specific deployment on the router
            response = await llm_router.aembedding(**data, specific_deployment=True)
        else:
            response = await litellm.aembedding(**data)

        ### ALERTING ###
        end_time = time.time()
        asyncio.create_task(
            proxy_logging_obj.response_taking_too_long(
                start_time=start_time, end_time=end_time, type="slow_response"
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e
        )
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise e
        error_traceback = traceback.format_exc()
        error_msg = f"{str(e)}\n\n{error_traceback}"
        # use the exception's own status code when it has one (no local named
        # `status`, which would shadow the fastapi import inside this function)
        error_status_code = getattr(e, "status_code", 500)
        raise HTTPException(status_code=error_status_code, detail=error_msg)
async def image_generation(
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    background_tasks: BackgroundTasks = BackgroundTasks(),
):
    """
    Image-generation endpoint handler.

    Parses the body with orjson, records the original request under
    `proxy_server_request`, applies server / CLI model overrides, runs the pre-call
    hook and sends the request to the router or litellm.aimage_generation.

    On failure the post-call failure hook runs; HTTPExceptions are re-raised as-is and
    everything else is raised as an HTTPException carrying the error text and traceback.
    """
    global proxy_logging_obj
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data["proxy_server_request"] = {
            "url": str(request.url),
            "method": request.method,
            "headers": dict(request.headers),
            "body": copy.copy(data),  # use copy instead of deepcopy
        }

        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
            data["user"] = user_api_key_dict.user_id

        data["model"] = (
            general_settings.get("image_generation_model", None)  # server default
            or user_model  # model name passed via cli args
            or data["model"]  # default passed in http request
        )
        if user_model:
            data["model"] = user_model

        metadata = data.setdefault("metadata", {})
        metadata["user_api_key"] = user_api_key_dict.api_key
        metadata["headers"] = dict(request.headers)
        metadata["user_api_key_user_id"] = user_api_key_dict.user_id

        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )

        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
        # NOTE(review): call_type is "embeddings" here although this is the image
        # endpoint - confirm whether the hook accepts an image-specific call type.
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
        )

        start_time = time.time()

        ## ROUTE TO CORRECT ENDPOINT ##
        # skip router if user passed their key
        if "api_key" in data:
            response = await litellm.aimage_generation(**data)
        elif (
            llm_router is not None and data["model"] in router_model_names
        ):  # model in router model list
            response = await llm_router.aimage_generation(**data)
        elif (
            llm_router is not None and data["model"] in llm_router.deployment_names
        ):  # model in router deployments, calling a specific deployment on the router
            response = await llm_router.aimage_generation(
                **data, specific_deployment=True
            )
        elif (
            llm_router is not None
            and llm_router.model_group_alias is not None
            and data["model"] in llm_router.model_group_alias
        ):  # model set in model_group_alias
            response = await llm_router.aimage_generation(
                **data
            )  # ensure this goes the llm_router, router will do the correct alias mapping
        else:
            response = await litellm.aimage_generation(**data)

        ### ALERTING ###
        end_time = time.time()
        asyncio.create_task(
            proxy_logging_obj.response_taking_too_long(
                start_time=start_time, end_time=end_time, type="slow_response"
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e
        )
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise e
        error_traceback = traceback.format_exc()
        error_msg = f"{str(e)}\n\n{error_traceback}"
        # use the exception's own status code when it has one (no local named
        # `status`, which would shadow the fastapi import inside this function)
        error_status_code = getattr(e, "status_code", 500)
        raise HTTPException(status_code=error_status_code, detail=error_msg)
#### KEY MANAGEMENT #### | |
async def generate_key_fn(
    request: Request,
    data: GenerateKeyRequest,
    Authorization: Optional[str] = Header(None),
):
    """
    Generate an API key based on the provided data.

    Docs: https://docs.litellm.ai/docs/proxy/virtual_keys

    Parameters:
    - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)**
    - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
    - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
    - config: Optional[dict] - any key-specific configs, overrides config in config.yaml
    - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
    - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.
    - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "user@example.com" }

    Returns:
    - key: (str) The generated api key
    - expires: (datetime) Datetime object for when key expires.
    - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
    """
    verbose_proxy_logger.debug("entered /key/generate")
    # the request model's fields are forwarded as keyword arguments to the helper
    key_params = data.json()  # type: ignore
    created = await generate_key_helper_fn(**key_params)
    return GenerateKeyResponse(
        key=created["token"],
        expires=created["expires"],
        user_id=created["user_id"],
    )
async def update_key_fn(request: Request, data: UpdateKeyRequest):
    """
    Apply the non-null fields of the request body to an existing key.

    Returns a dict holding the key plus every field that was forwarded to the
    database. Any failure inside the handler (including a missing DB client)
    is reported as HTTP 400 with the error text in `detail["error"]`.
    """
    global prisma_client
    try:
        payload: dict = data.json()
        token = payload.pop("key")
        if prisma_client is None:
            raise Exception("Not connected to DB!")
        # Only forward the fields the caller actually set.
        changes = {}
        for field, value in payload.items():
            if value is not None:
                changes[field] = value
        await prisma_client.update_data(
            token=token, data={**changes, "token": token}
        )
        return {"key": token, **changes}
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": str(err)},
        )
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
    """
    Delete the keys listed in `data.keys`.

    Returns `{"deleted_keys": [...]}` echoing the requested keys when the
    number reported by `delete_verification_token` equals the number of keys
    requested. Any failure, including a count mismatch, is reported as
    HTTP 400 with the error text in `detail["error"]`.
    """
    try:
        keys = data.keys
        deleted_keys = await delete_verification_token(tokens=keys)
        # Explicit check instead of `assert`: asserts are removed under `-O`
        # and produce an empty error message for the client.
        if len(keys) != deleted_keys:
            raise Exception(
                f"Not all keys were deleted. Requested: {len(keys)}, deleted: {deleted_keys}"
            )
        return {"deleted_keys": keys}
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": str(e)},
        ) from e
async def info_key_fn(
    key: str = fastapi.Query(..., description="Key in the request parameters")
):
    """
    Look up the stored record for `key`.

    Returns `{"key": key, "info": <row returned by the DB client>}`. Any
    failure, including a missing DB connection, is reported as HTTP 400 with
    the error text in `detail["error"]`.
    """
    global prisma_client
    try:
        if prisma_client is not None:
            stored_row = await prisma_client.get_data(token=key)
            return {"key": key, "info": stored_row}
        raise Exception(
            "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
        )
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": str(err)},
        )
#### USER MANAGEMENT #### | |
async def new_user(data: NewUserRequest):
    """
    Create a new user with a budget and issue a key for that user.

    Parameters (fields of `data`):
    - user_id: Optional[str] - User id to use. If not set, a unique id will be generated.
    - max_budget: Optional[float] - Max budget for the user.
    - duration: Optional[str] - How long the token stays valid. Accepts seconds ("30s"), minutes ("30m"), hours ("30h") or days ("30d"). **(Default is set to 1 hour.)**
    - models: Optional[list] - Model names the key may call. (An empty list allows every model.)
    - aliases: Optional[dict] - Alias mappings applied on top of the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
    - config: Optional[dict] - Key-specific config; overrides values from config.yaml.
    - spend: Optional[int] - Amount already spent by the key. Default is 0. Updated by the proxy whenever the key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
    - max_parallel_requests: Optional[int] - Limit on concurrent requests for the key; a 429 is raised once it is exceeded.
    - metadata: Optional[dict] - Free-form information stored with the key. Example metadata = {"team": "core-infra", "app": "app2", "email": "user@example.com"}

    Returns:
    - key: (str) The generated api key.
    - expires: (datetime) When the key expires.
    - user_id: (str) Unique user id - used to track spend across all keys of the same user.
    - max_budget: (float|None) Max budget for the user.
    """
    # NOTE(review): the result is unpacked as keyword arguments, so the request
    # model's json() must yield a mapping here - confirm against NewUserRequest.
    user_kwargs = data.json()  # type: ignore
    created = await generate_key_helper_fn(**user_kwargs)
    return NewUserResponse(
        key=created["token"],
        expires=created["expires"],
        user_id=created["user_id"],
        max_budget=created["max_budget"],
    )
async def user_auth(request: Request):
    """
    Email a magic login link to the address given in the request body.

    The JSON body is read for `user_email` (required) and `page` (optional,
    echoed back as a query parameter of the link). If a row with that email
    exists in the "users" table, a 24 hr key is generated for its user id;
    otherwise a key is generated for the email itself. The link points at
    os.getenv("LITELLM_HOSTED_UI"), defaulting to "https://dashboard.litellm.ai/".

    NOTE(review): earlier documentation for this endpoint states that it only
    allows emails from accepted email subdomains, is rate limited to 1 request
    every 60s, and only works when 'allow_user_auth' is enabled in general
    settings:
    ```yaml
    general_settings:
        allow_user_auth: true
    ```
    None of these checks is performed inside this function - presumably they
    are enforced elsewhere; confirm.

    Requirements:
    SMTP server details saved in .env:
    - os.environ["SMTP_HOST"]
    - os.environ["SMTP_PORT"]
    - os.environ["SMTP_USERNAME"]
    - os.environ["SMTP_PASSWORD"]
    - os.environ["SMTP_SENDER_EMAIL"]

    Raises:
    - HTTPException(400) when `user_email` is missing or null.
    - Exception when no database is connected.
    """
    global prisma_client
    # Local import keeps this block self-contained.
    from urllib.parse import urlencode

    data = await request.json()  # type: ignore
    # .get() so a missing field reaches the 400 below instead of a KeyError.
    user_email = data.get("user_email")
    page_params = data.get("page")
    if user_email is None:
        raise HTTPException(status_code=400, detail="User email is none")
    if prisma_client is None:  # if no db connected, raise an error
        raise Exception("No connected db.")
    ### Check if user email in user table
    existing_user = await prisma_client.get_generic_data(
        key="user_email", value=user_email, table_name="users"
    )
    key_params = {
        "duration": "24hr",
        "models": [],
        "aliases": {},
        "config": {},
        "spend": 0,
    }
    if existing_user is not None:
        ### known user - generate a 24 hr key with that user id
        key_params["user_id"] = existing_user.user_id
    else:
        ### unknown email - create new user
        key_params["user_email"] = user_email
    response = await generate_key_helper_fn(**key_params)  # type: ignore

    base_url = os.getenv("LITELLM_HOSTED_UI", "https://dashboard.litellm.ai/")
    # URL-encode the query values: `page` comes straight from the request body
    # and ends up inside the HTML of the email.
    query = {"token": response["token"], "user_id": response["user_id"]}
    if page_params is not None:
        query["page"] = page_params
    login_link = f"{base_url}user/?{urlencode(query)}"
    params = {
        "sender_name": "LiteLLM Proxy",
        "sender_email": os.getenv("SMTP_SENDER_EMAIL"),
        "receiver_email": user_email,
        "subject": "Your Magic Link",
        "html": f"<strong> Follow this link, to login:\n\n{login_link}</strong>",
    }
    await send_email(**params)
    return "Email sent!"
async def user_info(
    user_id: str = fastapi.Query(..., description="User ID in the request parameters")
):
    """
    Return the user row for `user_id` together with all of that user's keys.

    The response has the shape
    `{"user_id": ..., "user_info": <user row>, "keys": <key rows>}`. Any
    failure, including a missing DB connection, is reported as HTTP 400 with
    the error text in `detail["error"]`.
    """
    global prisma_client
    try:
        if prisma_client is None:
            raise Exception(
                "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )
        ## GET USER ROW ##
        user_row = await prisma_client.get_data(user_id=user_id)
        ## GET ALL KEYS ##
        user_keys = await prisma_client.get_data(
            user_id=user_id, table_name="key", query_type="find_all"
        )
        return {"user_id": user_id, "user_info": user_row, "keys": user_keys}
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": str(err)},
        )
async def user_update(request: Request):
    """
    [TODO]: Use this to update user budget

    Placeholder: the body does nothing and returns None.
    """
    return None
#### MODEL MANAGEMENT #### | |
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964 | |
async def add_new_model(model_params: ModelParams):
    """
    Append a model entry to the proxy config and save the config.

    The entry is built from `model_params.model_name`,
    `model_params.litellm_params` and the non-null fields of
    `model_params.model_info`.

    Returns `{"message": "Model added successfully"}`.

    Raises:
    - HTTPException: re-raised unchanged if one occurs while loading or saving.
    - HTTPException(500): for any other error, with the error text in `detail`.
    """
    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
    try:
        # Load existing config
        config = await proxy_config.get_config()
        verbose_proxy_logger.debug(f"User config path: {user_config_file_path}")
        verbose_proxy_logger.debug(f"Loaded config: {config}")
        # Drop unset fields so they are not written into the saved config.
        model_info = {
            k: v for k, v in model_params.model_info.json().items() if v is not None
        }
        # A config without any model yet has no (or a null) "model_list";
        # start one instead of failing on the append below.
        if config.get("model_list") is None:
            config["model_list"] = []
        config["model_list"].append(
            {
                "model_name": model_params.model_name,
                "litellm_params": model_params.litellm_params,
                "model_info": model_info,
            }
        )
        verbose_proxy_logger.debug(f"updated model list: {config['model_list']}")
        # Save new config
        await proxy_config.save_config(new_config=config)
        return {"message": "Model added successfully"}
    except HTTPException:
        traceback.print_exc()
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(
            status_code=500, detail=f"Internal Server Error: {str(e)}"
        )
#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info | |
async def model_info_v1(request: Request):
    """
    Return every model in the proxy config, enriched with litellm model info.

    For each model, keys returned by `get_litellm_model_info` are added to the
    model's `model_info` only where the config does not already define them,
    and `api_key` is removed from `litellm_params` before returning.

    Returns `{"data": [<model dict>, ...]}`; the list is empty when the config
    has no model list.
    """
    global llm_model_list, general_settings, user_config_file_path, proxy_config
    # Load existing config
    config = await proxy_config.get_config()
    # A missing or null "model_list" means there is nothing to report.
    all_models = config.get("model_list") or []
    for model in all_models:
        # model_info provided in config.yaml (may be absent or null)
        model_info = model.get("model_info") or {}
        # read litellm model_prices_and_context_window.json to get the following:
        # input_cost_per_token, output_cost_per_token, max_tokens
        litellm_model_info = get_litellm_model_info(model=model)
        for k, v in litellm_model_info.items():
            # values set explicitly in the config win over litellm defaults
            model_info.setdefault(k, v)
        model["model_info"] = model_info
        # don't return the api key
        model["litellm_params"].pop("api_key", None)
    verbose_proxy_logger.debug(f"all_models: {all_models}")
    return {"data": all_models}
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964 | |
async def delete_model(model_info: ModelInfoDelete):
    """
    Remove the model whose `model_info.id` equals `model_info.id` from the
    proxy config and save the config.

    Returns `{"message": "Model deleted successfully"}`.

    Raises:
    - HTTPException(404): the config file path does not exist.
    - HTTPException(400): the config has no models, or no model has that id.
    - HTTPException(500): any other error, with the error text in `detail`.
    """
    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
    try:
        if not os.path.exists(user_config_file_path):
            raise HTTPException(status_code=404, detail="Config file does not exist.")

        # Load existing config
        config = await proxy_config.get_config()

        # Without a model list there is nothing that could be deleted.
        if len(config.get("model_list", [])) == 0:
            raise HTTPException(
                status_code=400, detail="No model list available in the config."
            )

        # First model whose id matches, or None.
        target = next(
            (
                candidate
                for candidate in config["model_list"]
                if candidate.get("model_info", {}).get("id", None) == model_info.id
            ),
            None,
        )
        if target is None:
            raise HTTPException(
                status_code=400, detail="Model with given model_id not found."
            )

        # Remove model from the list and save the updated config
        config["model_list"].remove(target)
        await proxy_config.save_config(new_config=config)
        return {"message": "Model deleted successfully"}
    except HTTPException:
        # Let FastAPI handle the HTTP errors raised above unchanged.
        raise
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(err)}")
#### EXPERIMENTAL QUEUING #### | |
async def _litellm_chat_completions_worker(data, user_api_key_dict):
    """
    Run one chat completion call, retrying while the parallel-request limit is hit.

    Each attempt passes `data` through the pre-call hook, then sends it to the
    router (when the requested model is a router model name, a router
    deployment name, or a model group alias) or to `litellm.acompletion`
    otherwise. An HTTPException with status 429 whose detail contains
    "Max parallel request limit reached" causes a sleep and another attempt;
    any other HTTPException is re-raised.

    Returns the completion response of the first successful attempt.
    """
    while True:
        try:
            ### CALL HOOKS ### - modify incoming data before calling the model
            data = await proxy_logging_obj.pre_call_hook(
                user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
            )
            verbose_proxy_logger.debug("_litellm_chat_completions_worker started")

            ### ROUTE THE REQUEST ###
            if llm_model_list is None:
                router_model_names = []
            else:
                router_model_names = [m["model_name"] for m in llm_model_list]

            if llm_router is None:  # router is not set
                response = await litellm.acompletion(**data)
            elif data["model"] in router_model_names:  # model in router model list
                response = await llm_router.acompletion(**data)
            elif data["model"] in llm_router.deployment_names:
                # model in router deployments, calling a specific deployment on the router
                response = await llm_router.acompletion(
                    **data, specific_deployment=True
                )
            elif (
                llm_router.model_group_alias is not None
                and data["model"] in llm_router.model_group_alias
            ):  # model set in model_group_alias
                response = await llm_router.acompletion(**data)
            else:  # router is set but does not know this model
                response = await litellm.acompletion(**data)

            verbose_proxy_logger.debug(f"final response: {response}")
            return response
        except HTTPException as e:
            verbose_proxy_logger.debug(
                f"EXCEPTION RAISED IN _litellm_chat_completions_worker - {e.status_code}; {e.detail}"
            )
            hit_parallel_limit = (
                e.status_code == 429
                and "Max parallel request limit reached" in e.detail
            )
            if not hit_parallel_limit:
                raise e
            verbose_proxy_logger.debug("Max parallel request limit reached!")
            # Back off before the next loop iteration retries the call.
            timeout = litellm._calculate_retry_after(
                remaining_retries=3, max_retries=3, min_timeout=1
            )
            await asyncio.sleep(timeout)
async def async_queue_request(
    request: Request,
    model: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    background_tasks: BackgroundTasks = BackgroundTasks(),
):
    """
    v2 attempt at a background worker to handle queuing.
    Just supports /chat/completion calls currently.
    Now using a FastAPI background task + /chat/completions compatible endpoint

    The JSON body is forwarded to `_litellm_chat_completions_worker` after the
    original request, key metadata and CLI overrides are attached. The call is
    bounded by `litellm.request_timeout`. When the body sets `stream` to true
    the result is wrapped in a `text/event-stream` StreamingResponse; otherwise
    the response is returned as is.

    `background_tasks` is accepted but not read in the body.

    Raises:
    - HTTPException(400): for any error, after the post-call failure hook ran.
    """
    global general_settings, user_debug, proxy_logging_obj
    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
    try:
        data = await request.json()  # type: ignore
        # Include original request and headers in the data
        data["proxy_server_request"] = {
            "url": str(request.url),
            "method": request.method,
            "headers": dict(request.headers),
            "body": copy.copy(data),  # use copy instead of deepcopy
        }
        verbose_proxy_logger.debug(f"receiving data: {data}")
        data["model"] = (
            general_settings.get("completion_model", None)  # server default
            or user_model  # model name passed via cli args
            or model  # for azure deployments
            or data["model"]  # default passed in http request
        )
        # users can pass in 'user' param to /chat/completions. Don't override it
        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
            # if users are using user_api_key_auth, set `user` in `data`
            data["user"] = user_api_key_dict.user_id

        if "metadata" in data:
            verbose_proxy_logger.debug(f'received metadata: {data["metadata"]}')
        else:
            data["metadata"] = {}
        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
        data["metadata"]["headers"] = dict(request.headers)
        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id

        # override with user settings, these are params passed via cli
        if user_temperature:
            data["temperature"] = user_temperature
        if user_request_timeout:
            data["request_timeout"] = user_request_timeout
        if user_max_tokens:
            data["max_tokens"] = user_max_tokens
        if user_api_base:
            data["api_base"] = user_api_base

        response = await asyncio.wait_for(
            _litellm_chat_completions_worker(
                data=data, user_api_key_dict=user_api_key_dict
            ),
            timeout=litellm.request_timeout,
        )
        if (
            "stream" in data and data["stream"] == True
        ):  # use generate_responses to stream responses
            return StreamingResponse(
                async_data_generator(
                    user_api_key_dict=user_api_key_dict, response=response
                ),
                media_type="text/event-stream",
            )
        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e
        )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": str(e)},
        )
async def retrieve_server_log(request: Request):
    """
    Return the file at ~/.ollama/logs/server.log (with `~` expanded) as a
    FileResponse. `request` is not read in the body.
    """
    return FileResponse(os.path.expanduser("~/.ollama/logs/server.log"))
#### BASIC ENDPOINTS #### | |
async def update_config(config_info: ConfigYAML):
    """
    For Admin UI - allows admin to update config via UI

    Currently supports modifying General Settings, environment variables and
    LiteLLM settings. For each section present in `config_info`, its values
    are merged over the section already in the config, so a submitted value
    replaces the stored one and untouched keys are kept. The merged config is
    saved, and if "slack" is listed under `general_settings.alerting` a test
    alert is sent.

    Returns `{"message": "Config updated successfully"}`.

    Raises:
    - HTTPException: re-raised unchanged if one occurs.
    - HTTPException(500): for any other error, with the error text in `detail`.
    """
    global llm_router, llm_model_list, general_settings, proxy_config, proxy_logging_obj
    try:
        # Load existing config
        config = await proxy_config.get_config()
        verbose_proxy_logger.debug(f"Loaded config: {config}")

        # In every merge below the submitted values come last so that they
        # override what is stored; `or {}` tolerates a missing or null section.

        # update the general settings
        if config_info.general_settings is not None:
            config["general_settings"] = {
                **(config.get("general_settings") or {}),
                **config_info.general_settings.dict(exclude_none=True),
            }

        # update the environment variables
        if config_info.environment_variables is not None:
            config["environment_variables"] = {
                **(config.get("environment_variables") or {}),
                **config_info.environment_variables,
            }

        # update the litellm settings
        if config_info.litellm_settings is not None:
            config["litellm_settings"] = {
                **(config.get("litellm_settings") or {}),
                **config_info.litellm_settings,
            }

        # Save the updated config
        await proxy_config.save_config(new_config=config)

        # Test new connections
        ## Slack
        alerting = (config.get("general_settings") or {}).get("alerting") or []
        if "slack" in alerting:
            await proxy_logging_obj.alerting_handler(
                message="This is a test", level="Low"
            )
        return {"message": "Config updated successfully"}
    except HTTPException as e:
        raise e
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"An error occurred - {str(e)}")
async def config_yaml_endpoint(config_info: ConfigYAML):
    """
    Mock endpoint that exists to show, in the Swagger UI, what can be set in config.yaml.

    Parameters:

    The config.yaml object has the following attributes:
    - **model_list**: *Optional[List[ModelParams]]* - The models supported on the server, with their model-specific configuration. ModelParams includes "model_name" (name of the model), "litellm_params" (litellm-specific parameters for the model), and "model_info" (additional info about the model such as id, mode, cost per token, etc).
    - **litellm_settings**: *Optional[dict]*: Settings for the litellm module, e.g. "drop_params", "set_verbose", "api_base", "cache".
    - **general_settings**: *Optional[ConfigGeneralSettings]*: General server settings such as "completion_model" (default model for chat completion calls), "use_azure_key_vault" (option to load keys from azure key vault), "master_key" (key required for all calls to proxy), and others.

    Please refer to each class's description for the specific attributes within them.

    Note: This endpoint is for demonstration only. It does not read or change
    any configuration; `config_info` is ignored and a fixed payload is returned.
    """
    mock_payload = {"hello": "world"}
    return mock_payload
async def test_endpoint(request: Request):
    """
    Ping endpoint used to check that the proxy server is up.

    Parameters:
        request (Request): The incoming request.

    Returns:
        dict: `{"route": <path of the request URL>}`.
    """
    requested_path = request.url.path
    return {"route": requested_path}
async def health_endpoint(
    request: Request,
    model: Optional[str] = fastapi.Query(
        None, description="Specify the model name (optional)"
    ),
):
    """
    Check the health of all the endpoints in config.yaml

    To run health checks in the background, add this to config.yaml:
    ```
    general_settings:
        # ... other settings
        background_health_checks: True
    ```
    else, the health checks will be run on models when /health is called.

    Without a model list, the model passed on the CLI (if any) is checked
    instead; with neither, HTTP 500 "Model list not initialized" is raised.
    """
    global health_check_results, use_background_health_checks, user_model

    if llm_model_list is None:
        # no router set: fall back to a model given via `litellm --model ollama/llama2`
        if user_model is None:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"error": "Model list not initialized"},
            )
        healthy, unhealthy = await perform_health_check(
            model_list=[], cli_model=user_model
        )
    elif use_background_health_checks:
        # results are produced in the background; return the latest ones
        return health_check_results
    else:
        healthy, unhealthy = await perform_health_check(llm_model_list, model)

    return {
        "healthy_endpoints": healthy,
        "unhealthy_endpoints": unhealthy,
        "healthy_count": len(healthy),
        "unhealthy_count": len(unhealthy),
    }
async def health_readiness():
    """
    Unprotected endpoint for checking if worker can receive requests

    Reports healthy when no database client is configured, or when the
    configured client says it is connected. A configured client that is not
    connected results in HTTP 503 "Service Unhealthy".
    """
    global prisma_client
    if prisma_client is None:  # no db passed in
        return {"status": "healthy", "db": "Not connected"}
    if prisma_client.db.is_connected() == True:
        return {"status": "healthy", "db": "connected"}
    raise HTTPException(status_code=503, detail="Service Unhealthy")
async def health_liveliness():
    """
    Unprotected endpoint for checking if worker is alive

    Always answers with the same fixed string.
    """
    alive_message = "I'm alive!"
    return alive_message
async def home(request: Request):
    """Return a fixed status string; `request` is not read."""
    status_banner = "LiteLLM: RUNNING"
    return status_banner
async def get_routes():
    """
    Get a list of available routes in the FastAPI application.

    Returns `{"routes": [...]}` with one dict per entry of `app.routes`,
    holding its "path", "methods", "name" and the "endpoint" function name.
    An attribute that a route object does not have is reported as None
    instead of raising.
    """
    routes = []
    for route in app.routes:
        # Not every route type carries `methods` / `endpoint`, so read all
        # attributes defensively.
        endpoint = getattr(route, "endpoint", None)
        routes.append(
            {
                "path": getattr(route, "path", None),
                "methods": getattr(route, "methods", None),
                "name": getattr(route, "name", None),
                "endpoint": getattr(endpoint, "__name__", None) if endpoint else None,
            }
        )
    return {"routes": routes}
async def shutdown_event():
    """
    Shutdown hook: disconnect the database client if one is set, then reset
    the module-level configuration variables.
    """
    global prisma_client, master_key, user_custom_auth
    db_client = prisma_client
    if db_client:
        verbose_proxy_logger.debug("Disconnecting from Prisma")
        await db_client.disconnect()

    ## RESET CUSTOM VARIABLES ##
    cleanup_router_config_variables()
def cleanup_router_config_variables():
    """
    Reset the module-level auth, config-path, otel logging and health-check
    variables listed in the `global` statement to None.
    """
    global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval

    # One chained assignment: every listed global ends up bound to None.
    master_key = user_config_file_path = otel_logging = None
    user_custom_auth = user_custom_auth_path = None
    use_background_health_checks = health_check_interval = None
# Include `router` in the FastAPI `app`.
app.include_router(router)