2024-05-29 11:12:59 +08:00
|
|
|
|
# flake8: noqa
|
|
|
|
|
|
"""Tools for interacting with a SQL database."""
|
2024-07-08 18:50:01 +08:00
|
|
|
|
from typing import Any, Dict, Optional, Type
|
2024-05-29 11:12:59 +08:00
|
|
|
|
|
2025-12-30 16:49:08 +08:00
|
|
|
|
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
|
|
|
|
|
|
from langchain_community.tools.sql_database.tool import _QuerySQLCheckerToolInput
|
2024-05-29 11:12:59 +08:00
|
|
|
|
# from langchain.sql_database import SQLDatabase
|
2024-07-08 18:50:01 +08:00
|
|
|
|
from langchain_community.utilities import SQLDatabase
|
2025-12-30 16:49:08 +08:00
|
|
|
|
from langchain_core.callbacks import CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
|
|
|
|
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
|
|
|
|
from langchain_core.prompts import PromptTemplate
|
|
|
|
|
|
from langchain_core.tools import BaseTool
|
|
|
|
|
|
from pydantic import BaseModel, Extra, Field, root_validator
|
2024-05-29 11:12:59 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseSQLDatabaseTool(BaseModel):
    """Common fields and pydantic configuration shared by the SQL database tools."""

    # Database handle used by the concrete tools; left out when the model is exported.
    db: SQLDatabase = Field(exclude=True)
    # Defaults to an empty string.
    param_description: str = ""

    # BaseTool.Config is subclassed here to keep mypy satisfied.
    # See https://github.com/pydantic/pydantic/issues/4173
    class Config(BaseTool.Config):
        """Pydantic settings for this model."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool that executes a SQL query against the configured database.

    The query string is handed to ``self.db.run_no_throw`` and whatever that
    call returns (the query results or an error message) is passed back to
    the caller unchanged.
    """

    # Annotated like the other field overrides in this file; pydantic v2
    # rejects un-annotated overrides of fields inherited from a base model.
    name: str = "sql_db_query"

    # Alternative description, currently disabled:
    # description = """
    # Before use this tool, another tool named sql_db_schema must be used first to find the schema of interested tables.
    # This tool is designed exclusively for generating SELECT queries to retrieve clothing's img_name randomly from a MySQL database.
    # You should always use ‘order by rand()’ to randomly select data.
    # If the query is not correct, an error message will be returned.
    # If an error is returned, rewrite the query, check the query, and try again.
    # Always limit your query to at most 4 results.
    # Never query for all the columns from a specific table, only ask for the relevant columns given the question.
    # You MUST double check your query before executing it.
    # DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
    # """

    # NOTE: the fragments below are joined by implicit string concatenation,
    # so every fragment that ends a sentence must carry its own trailing
    # space; otherwise the sentences run together in the prompt.
    description: str = (
        "The input of this tool is a detailed and correct SQL select query statement, "
        "and the output is the result of the database, and it can only return up to 4 results. "
        "If the query is not correct, an error message will be returned. "
        "If an error is returned, rewrite the query, check the query, and try again. "
        "If you encounter an issue with Unknown column 'xxxx' in 'field list' or "
        "Table 'attribute_retrieval.xxxx' doesn't exist, "
        "use sql_db_schema to query the correct table fields. "
        "Example Input: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() "
        "LIMIT 1' "
        "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
        "order by rand() LIMIT 2'"
    )

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Execute the query, return the results or an error message.

        Args:
            query: SQL statement to execute.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            The value produced by ``self.db.run_no_throw(query)``.
        """
        return self.db.run_no_throw(query)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async execution is not available for this tool.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("QuerySqlDbTool does not support async")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for getting metadata about a SQL database.

    Takes a comma-separated list of table names and returns what
    ``self.db.get_table_info_no_throw`` reports for them.
    """

    # Annotated like the other field overrides in this file; pydantic v2
    # rejects un-annotated overrides of fields inherited from a base model.
    name: str = "sql_db_schema"

    # Alternative description, currently disabled:
    # description = """
    # The database contains information of lots of fashion items, such as item name, their fashion attributes.
    # There are five tables covering five fashion categories: top, pants, dress, skirt, and outwear.
    # Find the most relevant tables with the query, and output the schema of these tables.
    # """

    # NOTE: the fragments below are joined by implicit string concatenation,
    # so each one carries its own trailing space to keep sentences and list
    # items separated in the final prompt.
    description: str = (
        "Input to this tool is a comma-separated list of tables, "
        "output is the schema and sample rows for those tables. "
        "There are eight tables covering eight fashion categories: female_top, female_pants, female_dress, "
        "female_skirt, female_outwear, male_bottom, male_top, and male_outwear. "
        "Example Input: 'female_outwear, male_top'"
    )

    def _run(
        self,
        table_names: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get the schema for tables in a comma-separated list.

        Args:
            table_names: Table names separated by commas. Whitespace around
                each name is ignored, so ``"a,b"``, ``"a, b"`` and
                ``"a , b"`` are all treated the same way.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            The value produced by ``self.db.get_table_info_no_throw``.
        """
        # Split on the bare comma and strip each piece: the input is typically
        # LLM-generated and does not reliably use ", " as the separator.
        requested = [piece.strip() for piece in table_names.split(",")]
        return self.db.get_table_info_no_throw(requested)

    async def _arun(
        self,
        table_name: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async execution is not available for this tool.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("SchemaSqlDbTool does not support async")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for getting tables names."""

    # Both overrides are annotated like the other field overrides in this
    # file; pydantic v2 rejects un-annotated overrides of inherited fields.
    name: str = "sql_db_list_tables"
    description: str = (
        "Input is an empty string, output is a comma separated list of tables in the database."
    )

    def _run(
        self,
        tool_input: str = "",
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get a comma-separated list of the usable table names.

        Args:
            tool_input: Ignored; the tool takes no meaningful input.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            The names from ``self.db.get_usable_table_names()`` joined
            with ``", "``.
        """
        return ", ".join(self.db.get_usable_table_names())

    async def _arun(
        self,
        tool_input: str = "",
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async execution is not available for this tool.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("ListTablesSqlDbTool does not support async")
|
|
|
|
|
|
|
|
|
|
|
|
|
2024-07-08 18:50:01 +08:00
|
|
|
|
# class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
|
|
|
|
|
# """Use an LLM to check if a query is correct.
|
|
|
|
|
|
# Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
|
|
|
|
|
#
|
|
|
|
|
|
# template: str = QUERY_CHECKER
|
|
|
|
|
|
# llm: BaseLanguageModel
|
|
|
|
|
|
# llm_chain: LLMChain = Field(init=False)
|
|
|
|
|
|
# name = "sql_db_query_checker"
|
|
|
|
|
|
# description = (
|
|
|
|
|
|
# "Use this tools to double check if your query is correct before executing it."
|
|
|
|
|
|
# "Always use this tools before executing a query with sql_db_query!"
|
|
|
|
|
|
# )
|
|
|
|
|
|
#
|
|
|
|
|
|
# @root_validator(pre=True)
|
|
|
|
|
|
# def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
# if "llm_chain" not in values:
|
|
|
|
|
|
# values["llm_chain"] = LLMChain(
|
|
|
|
|
|
# llm=values.get("llm"),
|
|
|
|
|
|
# prompt=PromptTemplate(
|
|
|
|
|
|
# template=QUERY_CHECKER,
|
|
|
|
|
|
# input_variables=["query", "dialect"]
|
|
|
|
|
|
# ),
|
|
|
|
|
|
# )
|
|
|
|
|
|
#
|
|
|
|
|
|
# if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
|
|
|
|
|
|
# # if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
|
|
|
|
|
|
# raise ValueError(
|
|
|
|
|
|
# "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
|
|
|
|
|
|
# )
|
|
|
|
|
|
#
|
|
|
|
|
|
# return values
|
2024-05-29 11:12:59 +08:00
|
|
|
|
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
    """Use an LLM to check if a query is correct.

    Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/

    Unless a chain is supplied explicitly, ``llm_chain`` is built as the
    pipeline ``PromptTemplate(QUERY_CHECKER) | llm`` and is driven through
    ``invoke`` / ``ainvoke``.
    """

    template: str = QUERY_CHECKER
    llm: BaseLanguageModel
    # Populated by ``initialize_llm_chain`` when the caller does not pass one.
    llm_chain: Any = Field(init=False)
    name: str = "sql_db_query_checker"
    description: str = """
    Use this tool to double check if your query is correct before executing it.
    Always use this tool before executing a query with sql_db_query!
    """
    args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput

    @root_validator(pre=True)
    def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build the default checker chain and validate its prompt variables.

        Args:
            values: Raw field values supplied to the model constructor.

        Returns:
            ``values``, with ``"llm_chain"`` filled in when it was missing.

        Raises:
            ValueError: If the first step of the chain does not declare
                exactly the input variables ``["dialect", "query"]``.
        """
        if "llm_chain" not in values:
            llm = values.get("llm")  # type: ignore[arg-type]
            prompt = PromptTemplate(
                template=QUERY_CHECKER, input_variables=["dialect", "query"]
            )
            # Prompt piped into the model; the result is a runnable sequence,
            # not a legacy ``LLMChain``.
            values["llm_chain"] = prompt | llm

        # ``.first`` is the leading step of the sequence, i.e. the prompt.
        if values["llm_chain"].first.input_variables != ["dialect", "query"]:
            raise ValueError(
                "LLM chain for QueryCheckerTool must have input variables ['dialect', 'query']"
            )

        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the LLM to check the query.

        Args:
            query: SQL statement to be reviewed.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            The model's response as text.
        """
        # The chain is a runnable sequence, which exposes ``invoke`` rather
        # than the legacy ``LLMChain.predict``.
        response = self.llm_chain.invoke(
            {"query": query, "dialect": self.db.dialect}
        )
        # Chat models return a message object carrying ``content``; plain
        # LLMs return the text directly.
        return str(getattr(response, "content", response))

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Asynchronously use the LLM to check the query.

        Args:
            query: SQL statement to be reviewed.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            The model's response as text.
        """
        response = await self.llm_chain.ainvoke(
            {"query": query, "dialect": self.db.dialect}
        )
        return str(getattr(response, "content", response))
|