feat chat robot 接口迁移
This commit is contained in:
79
app/service/chat_robot/script/database.py
Normal file
79
app/service/chat_robot/script/database.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Optional, List
|
||||
import json
|
||||
|
||||
from sqlalchemy import text
|
||||
# from langchain import SQLDatabase
|
||||
from langchain.utilities import SQLDatabase
|
||||
|
||||
|
||||
class CustomDatabase(SQLDatabase):
|
||||
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
|
||||
# def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
connection = self._engine.connect()
|
||||
all_table_names = self.get_usable_table_names()
|
||||
if table_names is not None:
|
||||
missing_tables = set(table_names).difference(all_table_names)
|
||||
if missing_tables:
|
||||
# raise ValueError(f"table_names {missing_tables} not found in database")
|
||||
return f"Table {','.join(missing_tables)} can not be found in the database"
|
||||
all_table_names = table_names
|
||||
meta_tables = [
|
||||
tbl
|
||||
for tbl in self._metadata.sorted_tables
|
||||
if tbl.name in set(all_table_names)
|
||||
]
|
||||
|
||||
tables = []
|
||||
for table in meta_tables:
|
||||
table_name = table.name
|
||||
column_names = table.columns.keys()
|
||||
table_info = f"Table: {table_name}\nColumns: \nID, \nimg_name\n"
|
||||
for column_name in column_names:
|
||||
if column_name not in ["ID", "img_name"]:
|
||||
query = text(f"SELECT DISTINCT {column_name} FROM {table_name}")
|
||||
result = connection.execute(query)
|
||||
enum_values: List[str] = [row[0] for row in result.fetchall()]
|
||||
column_info = f"{column_name}: {', '.join(enum_values)}\n"
|
||||
table_info += column_info
|
||||
|
||||
# table_info = f"Table: {table_name}\n"
|
||||
#
|
||||
# if self._sample_rows_in_table_info:
|
||||
# table_info += f"{self._get_sample_rows(table)}\n"
|
||||
tables.append(table_info)
|
||||
final_str = "\n\n".join(tables)
|
||||
return final_str
|
||||
|
||||
def run(self, command: str, fetch: str = "all") -> str:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
If the statement returns no rows, an empty string is returned.
|
||||
|
||||
"""
|
||||
with self._engine.begin() as connection:
|
||||
if self._schema is not None:
|
||||
if self.dialect == "snowflake":
|
||||
connection.exec_driver_sql(
|
||||
f"ALTER SESSION SET search_path='{self._schema}'"
|
||||
)
|
||||
elif self.dialect == "bigquery":
|
||||
connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'")
|
||||
else:
|
||||
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
|
||||
cursor = connection.execute(text(command))
|
||||
if cursor.rowcount:
|
||||
if fetch == "all":
|
||||
result = cursor.fetchall()
|
||||
elif fetch == "one":
|
||||
result = cursor.fetchone() # type: ignore
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||
|
||||
# Convert columns values to string to avoid issues with sqlalchmey
|
||||
# trunacating text
|
||||
if isinstance(result, list):
|
||||
return json.dumps([r[0] for r in result])
|
||||
|
||||
return json.dumps([result[0]])
|
||||
return ""
|
||||
Reference in New Issue
Block a user