# import os, traceback
# from fastapi import FastAPI, Request, HTTPException
# from fastapi.routing import APIRouter
# from fastapi.responses import StreamingResponse, FileResponse
# from fastapi.middleware.cors import CORSMiddleware
# import json, sys
# from typing import Optional
# sys.path.insert(
#     0, os.path.abspath("../")
# ) # Adds the parent directory to the system path - for litellm local dev
# import litellm
# try:
#     from litellm.deprecated_litellm_server.server_utils import set_callbacks, load_router_config, print_verbose
# except ImportError:
#     from litellm.deprecated_litellm_server.server_utils import set_callbacks, load_router_config, print_verbose
# import dotenv
# dotenv.load_dotenv() # load env variables
# app = FastAPI(docs_url="/", title="LiteLLM API")
# router = APIRouter()
# origins = ["*"]
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=origins,
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )
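# # NOTE: CORS is left wide open above (all origins, methods, and headers are allowed).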
# #### GLOBAL VARIABLES ####
# llm_router: Optional[litellm.Router] = None
# llm_model_list: Optional[list] = None
# server_settings: Optional[dict] = None
# set_callbacks() # sets litellm callbacks for logging if they exist in the environment
# if "CONFIG_FILE_PATH" in os.environ:
# llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
# else:
# llm_router, llm_model_list, server_settings = load_router_config(router=llm_router)
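# # A router config is read from CONFIG_FILE_PATH when that env var is set; otherwise
# # load_router_config falls back to its defaults. Judging from how llm_model_list is
# # consumed below, each entry is expected to carry at least "model_name" and
# # "litellm_params" keys - a rough sketch, values purely illustrative:
# #   model_list:
# #     - model_name: gpt-3.5-turbo
# #       litellm_params:
# #         model: gpt-3.5-turbo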
# #### API ENDPOINTS ####
# @router.get("/v1/models")
# @router.get("/models") # if project requires model list
# def model_list():
#     all_models = litellm.utils.get_valid_models()
#     if llm_model_list:
#         all_models += llm_model_list
#     return dict(
#         data=[
#             {
#                 "id": model,
#                 "object": "model",
#                 "created": 1677610602,
#                 "owned_by": "openai",
#             }
#             for model in all_models
#         ],
#         object="list",
#     )
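# # Illustrative call (host/port are assumptions, not defined in this file):
# #   curl http://localhost:8000/v1/models
# # returns {"object": "list", "data": [{"id": "<model>", "object": "model", ...}, ...]}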
# # for streaming
# def data_generator(response):
#     for chunk in response:
#         yield f"data: {json.dumps(chunk)}\n\n"
# @router.post("/v1/completions")
# @router.post("/completions")
# async def completion(request: Request):
#     data = await request.json()
#     response = litellm.completion(
#         **data
#     )
#     if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
#         return StreamingResponse(data_generator(response), media_type='text/event-stream')
#     return response
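# # The body is passed straight through to litellm.completion, so it takes the same
# # keys as a litellm call; a minimal illustrative body:
# #   {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}], "stream": true}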
# @router.post("/v1/embeddings")
# @router.post("/embeddings")
# async def embedding(request: Request):
#     try:
#         data = await request.json()
#         # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
#         if os.getenv("AUTH_STRATEGY", None) == "DYNAMIC" and "authorization" in request.headers: # if users pass LLM api keys as part of header
#             api_key = request.headers.get("authorization")
#             api_key = api_key.replace("Bearer", "").strip() # type: ignore
#             if len(api_key.strip()) > 0:
#                 api_key = api_key
#                 data["api_key"] = api_key
#         response = litellm.embedding(
#             **data
#         )
#         return response
#     except Exception as e:
#         error_traceback = traceback.format_exc()
#         error_msg = f"{str(e)}\n\n{error_traceback}"
#         return {"error": error_msg}
# @router.post("/v1/chat/completions")
# @router.post("/chat/completions")
# @router.post("/openai/deployments/{model:path}/chat/completions") # azure compatible endpoint
# async def chat_completion(request: Request, model: Optional[str] = None):
#     global llm_model_list, server_settings
#     try:
#         data = await request.json()
#         server_model = server_settings.get("completion_model", None) if server_settings else None
#         data["model"] = server_model or model or data["model"]
#         ## CHECK KEYS ##
#         # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
#         # env_validation = litellm.validate_environment(model=data["model"])
#         # if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
#         #     if "authorization" in request.headers:
#         #         api_key = request.headers.get("authorization")
#         #     elif "api-key" in request.headers:
#         #         api_key = request.headers.get("api-key")
#         #     print(f"api_key in headers: {api_key}")
#         #     if " " in api_key:
#         #         api_key = api_key.split(" ")[1]
#         #     print(f"api_key split: {api_key}")
#         #     if len(api_key) > 0:
#         #         api_key = api_key
#         #         data["api_key"] = api_key
#         #     print(f"api_key in data: {api_key}")
#         ## CHECK CONFIG ##
#         if llm_model_list and data["model"] in [m["model_name"] for m in llm_model_list]:
#             for m in llm_model_list:
#                 if data["model"] == m["model_name"]:
#                     for key, value in m["litellm_params"].items():
#                         data[key] = value
#                     break
#         response = litellm.completion(
#             **data
#         )
#         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
#             return StreamingResponse(data_generator(response), media_type='text/event-stream')
#         return response
#     except Exception as e:
#         error_traceback = traceback.format_exc()
#         error_msg = f"{str(e)}\n\n{error_traceback}"
#         # return {"error": error_msg}
#         raise HTTPException(status_code=500, detail=error_msg)
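# # Model resolution above is: server_settings["completion_model"] first, then the
# # {model} path parameter (azure-style route), then the "model" field in the body.
# # Illustrative request (host/port are assumptions):
# #   curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" \
# #     -d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}'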
# @router.post("/router/completions")
# async def router_completion(request: Request):
#     global llm_router
#     try:
#         data = await request.json()
#         if "model_list" in data:
#             llm_router = litellm.Router(model_list=data.pop("model_list"))
#         if llm_router is None:
#             raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body")
#         # openai.ChatCompletion.create replacement
#         response = await llm_router.acompletion(model="gpt-3.5-turbo",
#                                                 messages=[{"role": "user", "content": "Hey, how's it going?"}])
#         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
#             return StreamingResponse(data_generator(response), media_type='text/event-stream')
#         return response
#     except Exception as e:
#         error_traceback = traceback.format_exc()
#         error_msg = f"{str(e)}\n\n{error_traceback}"
#         return {"error": error_msg}
# @router.post("/router/embedding")
# async def router_embedding(request: Request):
#     global llm_router
#     try:
#         data = await request.json()
#         if "model_list" in data:
#             llm_router = litellm.Router(model_list=data.pop("model_list"))
#         if llm_router is None:
#             raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body")
#         response = await llm_router.aembedding(model="gpt-3.5-turbo", # type: ignore
#                                                messages=[{"role": "user", "content": "Hey, how's it going?"}])
#         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
#             return StreamingResponse(data_generator(response), media_type='text/event-stream')
#         return response
#     except Exception as e:
#         error_traceback = traceback.format_exc()
#         error_msg = f"{str(e)}\n\n{error_traceback}"
#         return {"error": error_msg}
# @router.get("/")
# async def home(request: Request):
# return "LiteLLM: RUNNING"
# app.include_router(router) |
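# # If this module were uncommented, the app could be served with uvicorn, e.g.
# # (module path is an assumption based on the imports above):
# #   uvicorn litellm.deprecated_litellm_server.main:app --host 0.0.0.0 --port 8000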