Spaces:
Sleeping
Sleeping
from litellm.proxy.db.base_client import CustomDB | |
from litellm.proxy._types import ( | |
DynamoDBArgs, | |
LiteLLM_VerificationToken, | |
LiteLLM_Config, | |
LiteLLM_UserTable, | |
) | |
from litellm import get_secret | |
from typing import Any, List, Literal, Optional, Union | |
import json | |
from datetime import datetime | |
class DynamoDBWrapper(CustomDB): | |
from aiodynamo.credentials import Credentials, StaticCredentials | |
credentials: Credentials | |
def __init__(self, database_arguments: DynamoDBArgs): | |
from aiodynamo.client import Client | |
from aiodynamo.credentials import Credentials, StaticCredentials | |
from aiodynamo.http.httpx import HTTPX | |
from aiodynamo.models import ( | |
Throughput, | |
KeySchema, | |
KeySpec, | |
KeyType, | |
PayPerRequest, | |
) | |
from yarl import URL | |
from aiodynamo.expressions import UpdateExpression, F, Value | |
from aiodynamo.models import ReturnValues | |
from aiodynamo.http.aiohttp import AIOHTTP | |
from aiohttp import ClientSession | |
self.throughput_type = None | |
if database_arguments.billing_mode == "PAY_PER_REQUEST": | |
self.throughput_type = PayPerRequest() | |
elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT": | |
if ( | |
database_arguments.read_capacity_units is not None | |
and isinstance(database_arguments.read_capacity_units, int) | |
and database_arguments.write_capacity_units is not None | |
and isinstance(database_arguments.write_capacity_units, int) | |
): | |
self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore | |
else: | |
raise Exception( | |
f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}" | |
) | |
self.database_arguments = database_arguments | |
self.region_name = database_arguments.region_name | |
async def connect(self): | |
""" | |
Connect to DB, and creating / updating any tables | |
""" | |
from aiodynamo.client import Client | |
from aiodynamo.credentials import Credentials, StaticCredentials | |
from aiodynamo.http.httpx import HTTPX | |
from aiodynamo.models import ( | |
Throughput, | |
KeySchema, | |
KeySpec, | |
KeyType, | |
PayPerRequest, | |
) | |
from yarl import URL | |
from aiodynamo.expressions import UpdateExpression, F, Value | |
from aiodynamo.models import ReturnValues | |
from aiodynamo.http.aiohttp import AIOHTTP | |
from aiohttp import ClientSession | |
async with ClientSession() as session: | |
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) | |
## User | |
try: | |
error_occurred = False | |
table = client.table(self.database_arguments.user_table_name) | |
if not await table.exists(): | |
await table.create( | |
self.throughput_type, | |
KeySchema(hash_key=KeySpec("user_id", KeyType.string)), | |
) | |
except Exception as e: | |
error_occurred = True | |
if error_occurred == True: | |
raise Exception( | |
f"Failed to create table - {self.database_arguments.user_table_name}.\nPlease create a new table called {self.database_arguments.user_table_name}\nAND set `hash_key` as 'user_id'" | |
) | |
## Token | |
try: | |
error_occurred = False | |
table = client.table(self.database_arguments.key_table_name) | |
if not await table.exists(): | |
await table.create( | |
self.throughput_type, | |
KeySchema(hash_key=KeySpec("token", KeyType.string)), | |
) | |
except Exception as e: | |
error_occurred = True | |
if error_occurred == True: | |
raise Exception( | |
f"Failed to create table - {self.database_arguments.key_table_name}.\nPlease create a new table called {self.database_arguments.key_table_name}\nAND set `hash_key` as 'token'" | |
) | |
## Config | |
try: | |
error_occurred = False | |
table = client.table(self.database_arguments.config_table_name) | |
if not await table.exists(): | |
await table.create( | |
self.throughput_type, | |
KeySchema(hash_key=KeySpec("param_name", KeyType.string)), | |
) | |
except Exception as e: | |
error_occurred = True | |
if error_occurred == True: | |
raise Exception( | |
f"Failed to create table - {self.database_arguments.config_table_name}.\nPlease create a new table called {self.database_arguments.config_table_name}\nAND set `hash_key` as 'param_name'" | |
) | |
async def insert_data( | |
self, value: Any, table_name: Literal["user", "key", "config"] | |
): | |
from aiodynamo.client import Client | |
from aiodynamo.credentials import Credentials, StaticCredentials | |
from aiodynamo.http.httpx import HTTPX | |
from aiodynamo.models import ( | |
Throughput, | |
KeySchema, | |
KeySpec, | |
KeyType, | |
PayPerRequest, | |
) | |
from yarl import URL | |
from aiodynamo.expressions import UpdateExpression, F, Value | |
from aiodynamo.models import ReturnValues | |
from aiodynamo.http.aiohttp import AIOHTTP | |
from aiohttp import ClientSession | |
async with ClientSession() as session: | |
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) | |
table = None | |
if table_name == "user": | |
table = client.table(self.database_arguments.user_table_name) | |
elif table_name == "key": | |
table = client.table(self.database_arguments.key_table_name) | |
elif table_name == "config": | |
table = client.table(self.database_arguments.config_table_name) | |
for k, v in value.items(): | |
if isinstance(v, datetime): | |
value[k] = v.isoformat() | |
await table.put_item(item=value) | |
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]): | |
from aiodynamo.client import Client | |
from aiodynamo.credentials import Credentials, StaticCredentials | |
from aiodynamo.http.httpx import HTTPX | |
from aiodynamo.models import ( | |
Throughput, | |
KeySchema, | |
KeySpec, | |
KeyType, | |
PayPerRequest, | |
) | |
from yarl import URL | |
from aiodynamo.expressions import UpdateExpression, F, Value | |
from aiodynamo.models import ReturnValues | |
from aiodynamo.http.aiohttp import AIOHTTP | |
from aiohttp import ClientSession | |
async with ClientSession() as session: | |
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) | |
table = None | |
key_name = None | |
if table_name == "user": | |
table = client.table(self.database_arguments.user_table_name) | |
key_name = "user_id" | |
elif table_name == "key": | |
table = client.table(self.database_arguments.key_table_name) | |
key_name = "token" | |
elif table_name == "config": | |
table = client.table(self.database_arguments.config_table_name) | |
key_name = "param_name" | |
response = await table.get_item({key_name: key}) | |
new_response: Any = None | |
if table_name == "user": | |
new_response = LiteLLM_UserTable(**response) | |
elif table_name == "key": | |
new_response = {} | |
for k, v in response.items(): # handle json string | |
if ( | |
(k == "aliases" or k == "config" or k == "metadata") | |
and v is not None | |
and isinstance(v, str) | |
): | |
new_response[k] = json.loads(v) | |
else: | |
new_response[k] = v | |
new_response = LiteLLM_VerificationToken(**new_response) | |
elif table_name == "config": | |
new_response = LiteLLM_Config(**response) | |
return new_response | |
async def update_data( | |
self, key: str, value: dict, table_name: Literal["user", "key", "config"] | |
): | |
from aiodynamo.client import Client | |
from aiodynamo.credentials import Credentials, StaticCredentials | |
from aiodynamo.http.httpx import HTTPX | |
from aiodynamo.models import ( | |
Throughput, | |
KeySchema, | |
KeySpec, | |
KeyType, | |
PayPerRequest, | |
) | |
from yarl import URL | |
from aiodynamo.expressions import UpdateExpression, F, Value | |
from aiodynamo.models import ReturnValues | |
from aiodynamo.http.aiohttp import AIOHTTP | |
from aiohttp import ClientSession | |
async with ClientSession() as session: | |
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name) | |
table = None | |
key_name = None | |
try: | |
if table_name == "user": | |
table = client.table(self.database_arguments.user_table_name) | |
key_name = "user_id" | |
elif table_name == "key": | |
table = client.table(self.database_arguments.key_table_name) | |
key_name = "token" | |
elif table_name == "config": | |
table = client.table(self.database_arguments.config_table_name) | |
key_name = "param_name" | |
else: | |
raise Exception( | |
f"Invalid table name. Needs to be one of - {self.database_arguments.user_table_name}, {self.database_arguments.key_table_name}, {self.database_arguments.config_table_name}" | |
) | |
except Exception as e: | |
raise Exception(f"Error connecting to table - {str(e)}") | |
# Initialize an empty UpdateExpression | |
actions: List = [] | |
for k, v in value.items(): | |
# Convert datetime object to ISO8601 string | |
if isinstance(v, datetime): | |
v = v.isoformat() | |
# Accumulate updates | |
actions.append((F(k), Value(value=v))) | |
update_expression = UpdateExpression(set_updates=actions) | |
# Perform the update in DynamoDB | |
result = await table.update_item( | |
key={key_name: key}, | |
update_expression=update_expression, | |
return_values=ReturnValues.none, | |
) | |
return result | |
async def delete_data( | |
self, keys: List[str], table_name: Literal["user", "key", "config"] | |
): | |
""" | |
Not Implemented yet. | |
""" | |
return super().delete_data(keys, table_name) | |