# Files
# AiDA_Python/app/service/chat_robot/script/database.py
# 2024-05-29 11:12:59 +08:00
#
# 80 lines
# 3.3 KiB
# Python

from typing import Optional, List
import json
from sqlalchemy import text
# from langchain import SQLDatabase
from langchain.utilities import SQLDatabase
class CustomDatabase(SQLDatabase):
    """``SQLDatabase`` with value-listing table descriptions and JSON query results.

    - ``get_table_info_no_throw`` describes tables by listing every distinct
      value of each column, and reports unknown tables instead of raising.
    - ``run`` returns the first column of a query result as a JSON array.
    """

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Describe tables by listing the distinct values stored in each column.

        Args:
            table_names: Tables to describe. ``None`` means every table returned
                by ``get_usable_table_names()``.

        Returns:
            One section per table, sections separated by a blank line. Each
            section starts with ``Table: <name>``, names the ``ID`` and
            ``img_name`` columns without values, then has one
            ``<column>: v1, v2, ...`` line per remaining column. If any requested
            table does not exist, only a "can not be found" message is returned.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                # Deliberately return a message instead of raising, so a bad
                # table name never surfaces as an exception to the caller.
                # sorted() keeps the message deterministic.
                return f"Table {','.join(sorted(missing_tables))} can not be found in the database"
            all_table_names = table_names

        wanted = set(all_table_names)  # built once, not per metadata table
        meta_tables = [tbl for tbl in self._metadata.sorted_tables if tbl.name in wanted]

        # The context manager guarantees the connection is returned to the pool.
        with self._engine.connect() as connection:
            tables = [self._describe_table_values(connection, table) for table in meta_tables]
        return "\n\n".join(tables)

    def _describe_table_values(self, connection, table) -> str:
        """Return the ``Table:`` section for one metadata table, listing distinct values per column."""
        from sqlalchemy import select

        # Columns named in the header only; their values are not enumerated.
        skipped_columns = ("ID", "img_name")
        parts = [f"Table: {table.name}\nColumns: \nID, \nimg_name\n"]
        for column in table.columns:
            if column.name in skipped_columns:
                continue
            # Build the query from the Column object so SQLAlchemy quotes the
            # identifiers and applies the table's schema; a hand-formatted
            # f-string breaks on spaces, reserved words, or mixed-case names.
            # NOTE(review): one DISTINCT query per column; high-cardinality
            # columns can make this text very long - consider a cap.
            rows = connection.execute(select(column).distinct()).fetchall()
            # str() so NULLs and non-text values (ints, dates, ...) can be joined.
            values = ", ".join(str(row[0]) for row in rows)
            parts.append(f"{column.name}: {values}\n")
        return "".join(parts)

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return the first column of its result as JSON.

        Args:
            command: Raw SQL text. It runs inside ``engine.begin()``, so it is
                committed on success and rolled back if an exception escapes.
            fetch: ``"all"`` to return every row, ``"one"`` for the first row only.

        Returns:
            A JSON array holding the first-column value of each fetched row.
            Values JSON cannot encode natively (e.g. ``Decimal``, ``datetime``)
            are converted with ``str``. A result set with no rows gives ``"[]"``.
            A statement that produces no result set (INSERT, UPDATE, DDL, ...)
            gives ``""``.

        Raises:
            ValueError: If ``fetch`` is neither ``"all"`` nor ``"one"``. This is
                checked before the command is executed.
        """
        if fetch not in ("all", "one"):
            raise ValueError("Fetch parameter must be either 'one' or 'all'")

        with self._engine.begin() as connection:
            self._apply_schema(connection)
            cursor = connection.execute(text(command))
            # returns_rows, not rowcount: per PEP 249, rowcount may be -1 for
            # SELECT (e.g. sqlite3), and it is positive for DML. Fetching from a
            # DML result raises, so rowcount cannot tell us whether to fetch.
            if not cursor.returns_rows:
                return ""
            if fetch == "all":
                values = [row[0] for row in cursor.fetchall()]
            else:
                row = cursor.fetchone()
                # fetchone() yields None on an empty result; don't index it.
                values = [] if row is None else [row[0]]
            return json.dumps(values, default=str)

    def _apply_schema(self, connection) -> None:
        """Point *connection* at ``self._schema`` using the dialect's syntax; no-op if unset."""
        if self._schema is None:
            return
        if self.dialect == "snowflake":
            connection.exec_driver_sql(f"ALTER SESSION SET search_path='{self._schema}'")
        elif self.dialect == "bigquery":
            connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'")
        else:
            connection.exec_driver_sql(f"SET search_path TO {self._schema}")