Spaces:
Runtime error
Runtime error
"""SQLAlchemy wrapper around a database.""" | |
from __future__ import annotations | |
from typing import Any, Iterable, List, Optional | |
from sqlalchemy import MetaData, create_engine, inspect, select, text | |
from sqlalchemy.engine import Engine | |
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError | |
from sqlalchemy.schema import CreateTable | |
class SQLDatabase:
    """SQLAlchemy wrapper around a database.

    Lists and reflects the database's tables once at construction time, then
    offers helpers to describe them (``CREATE TABLE`` text plus optional sample
    rows) and to execute raw SQL, returning results as strings.
    """

    def __init__(
        self,
        engine: Engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        custom_table_info: Optional[dict] = None,
    ):
        """Wrap an existing engine and reflect the tables it exposes.

        Args:
            engine: SQLAlchemy engine used for inspection and for queries.
            schema: Schema to list and reflect tables from. ``run`` also uses it
                as the ``search_path``. ``None`` means the default schema.
            metadata: ``MetaData`` to reflect into. A fresh one is created when
                omitted.
            ignore_tables: Tables to hide. Mutually exclusive with
                ``include_tables``.
            include_tables: When given, the only tables exposed.
            sample_rows_in_table_info: Number of sample rows appended to each
                table description. 0 disables sampling.
            custom_table_info: Mapping of table name -> description text used
                instead of the generated description. Entries for tables that
                are not in the database are dropped.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are
                given, or either names a table not found in the database.
            TypeError: If ``sample_rows_in_table_info`` is not an int, or a
                non-empty ``custom_table_info`` is not a dict.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)
        self._all_tables = set(self._inspector.get_table_names(schema=schema))

        self._include_tables = set(include_tables) if include_tables else set()
        missing_tables = self._include_tables - self._all_tables
        if missing_tables:
            raise ValueError(
                f"include_tables {missing_tables} not found in database"
            )

        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        missing_tables = self._ignore_tables - self._all_tables
        if missing_tables:
            raise ValueError(
                f"ignore_tables {missing_tables} not found in database"
            )

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")
        self._sample_rows_in_table_info = sample_rows_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # only keep the tables that are also present in the database
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in self._all_tables
            }

        self._metadata = metadata if metadata is not None else MetaData()
        # Reflect the same schema the table names were listed from. Otherwise
        # tables in a non-default schema are never described.
        self._metadata.reflect(bind=self._engine, schema=self._schema)

    @classmethod
    def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
        """Construct an instance from a database URI.

        Extra keyword arguments are forwarded to the constructor.
        """
        return cls(create_engine(database_uri), **kwargs)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns a new set, so callers cannot mutate this instance's state.
        """
        if self._include_tables:
            return set(self._include_tables)
        return self._all_tables - self._ignore_tables

    @property
    def table_info(self) -> str:
        """Information about all tables in the database."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Args:
            table_names: Tables to describe. Defaults to every available table.

        Returns:
            The table descriptions joined by blank lines, in the metadata's
            dependency order (``sorted_tables``).

        Raises:
            ValueError: If any requested table is not available.
        """
        all_table_names = set(self.get_table_names())
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = set(table_names)

        # SQLite reserves names starting with "sqlite_" for its internal tables.
        skip_sqlite_internal = self.dialect == "sqlite"
        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in all_table_names
            and not (skip_sqlite_internal and tbl.name.startswith("sqlite_"))
        ]

        tables = []
        for table in meta_tables:
            if self._custom_table_info and table.name in self._custom_table_info:
                tables.append(self._custom_table_info[table.name])
                continue
            # add create table command
            create_table = str(CreateTable(table).compile(self._engine))
            if self._sample_rows_in_table_info:
                tables.append(self._table_info_with_samples(table, create_table))
            else:
                tables.append(create_table)
        return "\n\n".join(tables)

    def _table_info_with_samples(self, table: Any, create_table: str) -> str:
        """Append a comment block with up to N sample rows to ``create_table``.

        ``table`` is a reflected SQLAlchemy ``Table``.
        """
        command = select(table).limit(self._sample_rows_in_table_info)
        columns_str = "\t".join(col.name for col in table.columns)
        try:
            with self._engine.connect() as connection:
                rows = connection.execute(command)
                # truncate each value to 100 characters
                sample_rows_str = "\n".join(
                    "\t".join(str(value)[:100] for value in row) for row in rows
                )
        # in some dialects when there are no rows in the table a
        # 'ProgrammingError' is returned
        except ProgrammingError:
            sample_rows_str = ""
        return (
            f"{create_table.rstrip()}\n"
            f"/*\n"
            f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}\n"
            f"*/"
        )

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned:
        - ``fetch="all"``: the ``str()`` of a dict with ``column_names`` and
          ``data`` (all fetched rows).
        - ``fetch="one"``: the first column of the first row.

        If the statement returns no rows, an empty string is returned. This
        also applies when ``fetch="one"`` finds no row.

        The command runs inside ``engine.begin()``, so it is committed on
        success and rolled back if an exception escapes.

        Raises:
            ValueError: If the statement returns rows and ``fetch`` is neither
                ``"one"`` nor ``"all"``.
        """
        with self._engine.begin() as connection:
            if self._schema is not None:
                # NOTE(review): SET search_path is PostgreSQL syntax; other
                # dialects will likely reject it when a schema is configured.
                connection.exec_driver_sql(f"SET search_path TO {self._schema}")
            cursor = connection.execute(text(command))
            if not cursor.returns_rows:
                return ""
            if fetch == "all":
                # keys() is portable across drivers. DB-API description entries
                # are plain tuples on several drivers, so `desc.name` fails there.
                column_names = list(cursor.keys())
                data = cursor.fetchall()
                result = {"column_names": column_names, "data": data}
            elif fetch == "one":
                row = cursor.fetchone()
                if row is None:
                    return ""
                result = row[0]
            else:
                raise ValueError("Fetch parameter must be either 'one' or 'all'")
            return str(result)

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Unlike ``get_table_info``, a ``ValueError`` (e.g. unknown table names)
        is returned as an ``"Error: ..."`` string instead of being raised.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message
            return f"Error: {e}"

    def run_no_throw(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        If the statement throws a ``SQLAlchemyError``, the error message is
        returned as an ``"Error: ..."`` string. Other exceptions (e.g. the
        ``ValueError`` for an invalid ``fetch``) still propagate.
        """
        try:
            return self.run(command, fetch)
        except SQLAlchemyError as e:
            # Format the error message
            return f"Error: {e}"