80 lines
3.4 KiB
Python
80 lines
3.4 KiB
Python
from typing import Optional, List
|
|
import json
|
|
|
|
from sqlalchemy import text
|
|
# from langchain import SQLDatabase
|
|
from langchain_community.utilities import SQLDatabase
|
|
|
|
|
|
class CustomDatabase(SQLDatabase):
    """SQLDatabase subclass with a non-raising table summary and a JSON ``run``.

    ``get_table_info_no_throw`` reports unknown tables as a plain message
    instead of raising, and lists the distinct values stored in each column.
    ``run`` returns the first column of the result rows as a JSON array.
    """

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Describe tables by listing the distinct values of their columns.

        Args:
            table_names: Tables to describe. ``None`` means every table
                returned by ``get_usable_table_names()``.

        Returns:
            One section per table, separated by blank lines. Each section has
            a fixed header naming the ``ID`` and ``img_name`` columns, followed
            by one ``column: value, value, ...`` line for every other column.
            If any requested table is unknown, a message naming the missing
            tables is returned instead of raising.
        """
        # Imported locally so this method does not depend on the module's
        # import block being changed.
        from sqlalchemy import select

        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                # Sorted so the message is deterministic across runs.
                return (
                    f"Table {','.join(sorted(missing_tables))} "
                    "can not be found in the database"
                )
            all_table_names = table_names

        # Build the lookup set once instead of once per reflected table.
        wanted = set(all_table_names)
        meta_tables = [
            tbl for tbl in self._metadata.sorted_tables if tbl.name in wanted
        ]

        tables = []
        # The context manager guarantees the connection is released, also
        # when one of the queries fails.
        with self._engine.connect() as connection:
            for table in meta_tables:
                parts = [f"Table: {table.name}\nColumns: \nID, \nimg_name\n"]
                for column_name, column in table.columns.items():
                    # These two columns are only named in the header; their
                    # values are never enumerated.
                    if column_name in ("ID", "img_name"):
                        continue
                    # Let SQLAlchemy render the statement so identifiers are
                    # quoted (and schema-qualified) correctly for the dialect.
                    result = connection.execute(select(column).distinct())
                    # str() keeps the join working for non-text columns
                    # (integers, NULLs, dates, ...).
                    enum_values: List[str] = [
                        str(row[0]) for row in result.fetchall()
                    ]
                    parts.append(f"{column_name}: {', '.join(enum_values)}\n")
                tables.append("".join(parts))

        return "\n\n".join(tables)

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return its first column as a JSON array.

        Only the first column of each returned row is included in the output.

        Args:
            command: SQL text to execute.
            fetch: ``"all"`` to return every row, ``"one"`` to return only
                the first row.

        Returns:
            A JSON array string with the first-column value of each fetched
            row, or an empty string if the statement returns no rows.

        Raises:
            ValueError: If the statement returns rows and ``fetch`` is
                neither ``"one"`` nor ``"all"``.
        """
        with self._engine.begin() as connection:
            if self._schema is not None:
                if self.dialect == "snowflake":
                    connection.exec_driver_sql(
                        f"ALTER SESSION SET search_path='{self._schema}'"
                    )
                elif self.dialect == "bigquery":
                    connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'")
                else:
                    connection.exec_driver_sql(f"SET search_path TO {self._schema}")

            cursor = connection.execute(text(command))

            # ``rowcount`` is the number of rows affected by DML and is
            # driver-dependent for SELECT; ``returns_rows`` is the attribute
            # that says whether there is anything to fetch.
            if not cursor.returns_rows:
                return ""

            if fetch == "all":
                rows = cursor.fetchall()
            elif fetch == "one":
                first = cursor.fetchone()
                # fetchone() yields None on an empty result.
                rows = [] if first is None else [first]
            else:
                raise ValueError("Fetch parameter must be either 'one' or 'all'")

            if not rows:
                return ""

            # default=str converts values JSON cannot encode natively
            # (Decimal, datetime, ...) to strings instead of raising.
            return json.dumps([row[0] for row in rows], default=str)
|