from typing import Optional, List
import json

from sqlalchemy import select, text
from langchain_community.utilities import SQLDatabase


class CustomDatabase(SQLDatabase):
    """``SQLDatabase`` variant used for prompting.

    ``get_table_info_no_throw`` describes tables by listing the distinct values of
    each column. ``run`` returns the first column of a query result as a JSON array.
    """

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Describe the requested tables, listing the distinct values of each column.

        Requested tables that do not exist do not raise (the base class raises
        ``ValueError``). Instead, a message naming the missing tables is returned.

        Args:
            table_names: Names of the tables to describe. ``None`` means every
                table returned by ``get_usable_table_names()``.

        Returns:
            One section per table, joined by blank lines. Each section is a fixed
            ``Table:``/``Columns:`` header that always names ``ID`` and ``img_name``,
            followed by one ``<column>: <v1>, <v2>, ...`` line for every other
            column. If any requested table is missing, only the message
            ``"Table <names> can not be found in the database"`` is returned.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                # Deliberately returned rather than raised, so callers always get
                # a string. Sorted so the message is deterministic.
                return (
                    f"Table {','.join(sorted(missing_tables))} "
                    "can not be found in the database"
                )
            all_table_names = table_names

        wanted = set(all_table_names)  # build the lookup set once, not per table
        meta_tables = [
            tbl for tbl in self._metadata.sorted_tables if tbl.name in wanted
        ]

        tables = []
        # The context manager guarantees the connection is closed/returned to the pool.
        with self._engine.connect() as connection:
            for table in meta_tables:
                # NOTE(review): the header always lists ID and img_name, whether or
                # not this table actually has those columns — confirm with callers.
                table_info = f"Table: {table.name}\nColumns: \nID, \nimg_name\n"
                for column in table.columns:
                    if column.name in ("ID", "img_name"):
                        continue
                    # Building the query from the reflected Column lets SQLAlchemy
                    # quote identifiers and schema-qualify the table, which raw
                    # f-string SQL did not.
                    query = select(column).distinct()
                    # str() so non-text values (numbers, NULL) don't break join().
                    values = [str(row[0]) for row in connection.execute(query)]
                    table_info += f"{column.name}: {', '.join(values)}\n"
                tables.append(table_info)
        return "\n\n".join(tables)

    def run(self, command: str, fetch: str = "all", **kwargs) -> str:
        """Execute a SQL command and return the first column of its result as JSON.

        If the statement produces rows, a JSON array of the first column's values
        is returned. With ``fetch="all"`` it has one entry per row; with
        ``fetch="one"`` it has a single entry taken from the first row. If the
        statement produces no rows (for example DDL/DML, or a SELECT with an empty
        result), an empty string is returned.

        Args:
            command: SQL text, executed verbatim inside a transaction that is
                committed on success. It must come from a trusted or constrained
                source.
            fetch: ``"all"`` to return every row, or ``"one"`` for only the first row.
            **kwargs: Accepted for signature compatibility with ``SQLDatabase.run``;
                ignored.

        Returns:
            A JSON array string, or ``""`` when there are no rows.

        Raises:
            ValueError: If ``fetch`` is neither ``"one"`` nor ``"all"``. This is
                checked before anything is executed, so an invalid call never runs
                or commits the command.
        """
        if fetch not in ("all", "one"):
            raise ValueError("Fetch parameter must be either 'one' or 'all'")

        with self._engine.begin() as connection:
            if self._schema is not None:
                if self.dialect == "snowflake":
                    connection.exec_driver_sql(
                        f"ALTER SESSION SET search_path='{self._schema}'"
                    )
                elif self.dialect == "bigquery":
                    connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'")
                else:
                    connection.exec_driver_sql(f"SET search_path TO {self._schema}")
            cursor = connection.execute(text(command))
            # Use returns_rows, not rowcount. rowcount is the affected-row count for
            # UPDATE/DELETE (fetching there raises), and some drivers report -1 for
            # SELECT, so it cannot tell whether a result set exists.
            if not cursor.returns_rows:
                return ""
            if fetch == "all":
                values = [row[0] for row in cursor.fetchall()]
            else:
                first = cursor.fetchone()
                values = [] if first is None else [first[0]]

        if not values:
            return ""
        # default=str stringifies values json cannot encode natively (Decimal,
        # datetime, ...) instead of raising TypeError.
        return json.dumps(values, default=str)