184 lines
7.1 KiB
Python
184 lines
7.1 KiB
Python
# flake8: noqa
|
||
"""Tools for interacting with a SQL database."""
|
||
from typing import Any, Dict, Optional
|
||
|
||
from pydantic import BaseModel, Extra, Field, root_validator
|
||
|
||
from langchain.base_language import BaseLanguageModel
|
||
from langchain.callbacks.manager import (
|
||
AsyncCallbackManagerForToolRun,
|
||
CallbackManagerForToolRun,
|
||
)
|
||
from langchain.chains.llm import LLMChain
|
||
from langchain.prompts import PromptTemplate
|
||
# from langchain.sql_database import SQLDatabase
|
||
from langchain.utilities import SQLDatabase
|
||
from langchain.tools.base import BaseTool
|
||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||
|
||
|
||
class BaseSQLDatabaseTool(BaseModel):
    """Common pydantic base shared by the SQL database tools in this module.

    It carries the database handle that the concrete tools call into.
    """

    # Database wrapper used by the tools; ``exclude=True`` marks the field
    # as excluded on the pydantic side.
    db: SQLDatabase = Field(exclude=True)
    # Free-form parameter description; defaults to an empty string.
    param_description: str = ""

    # Override BaseTool.Config to appease mypy
    # See https://github.com/pydantic/pydantic/issues/4173
    class Config(BaseTool.Config):
        """Pydantic settings applied to the tool models."""

        # Reject any field that is not declared on the model.
        extra = Extra.forbid
        # Permit non-pydantic types (such as the database wrapper) as fields.
        arbitrary_types_allowed = True
|
||
|
||
|
||
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for querying a SQL database."""

    name = "sql_db_query"

    # Earlier prompt wording, kept for reference:
    # description = """
    # Before use this tool, another tool named sql_db_schema must be used first to find the schema of interested tables.
    # This tool is designed exclusively for generating SELECT queries to retrieve clothing's img_name randomly from a MySQL database.
    # You should always use 'order by rand()' to randomly select data.
    # If the query is not correct, an error message will be returned.
    # If an error is returned, rewrite the query, check the query, and try again.
    # Always limit your query to at most 4 results.
    # Never query for all the columns from a specific table, only ask for the relevant columns given the question.
    # You MUST double check your query before executing it.
    # DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
    # """

    # Adjacent string literals are concatenated verbatim, so every fragment
    # that ends a sentence carries its own trailing space; without it the
    # prompt shown to the model runs sentences together ("results.If ...").
    # NOTE(review): the examples reference tables `skirt` and `top`, while the
    # sql_db_schema tool description lists `female_*` / `male_*` tables --
    # confirm which names the target database actually uses.
    description: str = (
        "The input of this tool is a detailed and correct SQL select query statement, "
        "and the output is the result of the database, and it can only return up to 4 results. "
        "If the query is not correct, an error message will be returned. "
        "If an error is returned, rewrite the query, check the query, and try again. "
        "If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist, "
        "use sql_db_schema to query the correct table fields. "
        "Example Input: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() "
        "LIMIT 1' "
        "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
        "order by rand() LIMIT 2'"
    )

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Execute the query, return the results or an error message.

        Args:
            query: SQL statement handed to ``self.db.run_no_throw``.
            run_manager: Optional callback manager; not used here.

        Returns:
            Whatever ``self.db.run_no_throw`` returns for ``query``.
        """
        return self.db.run_no_throw(query)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async entry point; not implemented for this tool.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("QuerySqlDbTool does not support async")
|
||
|
||
|
||
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for getting metadata about a SQL database."""

    name = "sql_db_schema"

    # Earlier prompt wording, kept for reference:
    # description = """
    # The database contains information of lots of fashion items, such as item name, their fashion attributes.
    # There are five tables covering five fashion categories: top, pants, dress, skirt, and outwear.
    # Find the most relevant tables with the query, and output the schema of these tables.
    # """

    # Adjacent string literals are concatenated verbatim, so each fragment
    # carries its own trailing space; otherwise the prompt contains run-ons
    # such as "tables.There" and "female_dress,female_skirt".
    description: str = (
        "Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. "
        "There are eight tables covering eight fashion categories: female_top, female_pants, female_dress, "
        "female_skirt, female_outwear, male_bottom, male_top, and male_outwear. "
        "Example Input: 'female_outwear, male_top'"
    )

    def _run(
        self,
        table_names: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get the schema for tables in a comma-separated list.

        Args:
            table_names: Comma-separated table names, e.g.
                ``"female_outwear, male_top"``. Whitespace around each name
                is ignored.
            run_manager: Optional callback manager; not used here.

        Returns:
            Whatever ``self.db.get_table_info_no_throw`` returns for the
            parsed list of names.
        """
        # Split on the bare comma and strip each piece so that "a,b",
        # "a, b" and "a ,  b" all resolve to the same table names; splitting
        # on ", " only handled the exact comma-plus-one-space form.
        requested_tables = [name.strip() for name in table_names.split(",")]
        return self.db.get_table_info_no_throw(requested_tables)

    async def _arun(
        self,
        table_name: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async entry point; not implemented for this tool.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("SchemaSqlDbTool does not support async")
|
||
|
||
|
||
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool that reports the names of the tables in the database."""

    name = "sql_db_list_tables"
    description = "Input is an empty string, output is a comma separated list of tables in the database."

    def _run(
        self,
        tool_input: str = "",
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Return the usable table names as one ", "-separated string.

        ``tool_input`` and ``run_manager`` are accepted but not read.
        """
        usable_tables = self.db.get_usable_table_names()
        return ", ".join(usable_tables)

    async def _arun(
        self,
        tool_input: str = "",
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async entry point; always raises ``NotImplementedError``."""
        raise NotImplementedError("ListTablesSqlDbTool does not support async")
|
||
|
||
|
||
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
    """Use an LLM to check if a query is correct.

    Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/
    """

    # NOTE(review): this field is declared but the validator below builds the
    # default prompt from QUERY_CHECKER directly, so a custom `template`
    # value is not picked up -- confirm whether that is intended.
    template: str = QUERY_CHECKER
    llm: BaseLanguageModel
    # Filled in by `initialize_llm_chain` when the caller does not supply one.
    llm_chain: LLMChain = Field(init=False)
    name = "sql_db_query_checker"
    # Adjacent literals are concatenated verbatim; the trailing space keeps
    # the two sentences from running together ("executing it.Always ...").
    description = (
        "Use this tools to double check if your query is correct before executing it. "
        "Always use this tools before executing a query with sql_db_query!"
    )

    @root_validator(pre=True)
    def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build the default checker chain if needed and validate its prompt.

        Args:
            values: Raw constructor values (this validator runs with
                ``pre=True``).

        Returns:
            ``values``, with ``"llm_chain"`` populated.

        Raises:
            ValueError: If the chain's prompt does not take exactly the
                input variables ``query`` and ``dialect``.
        """
        if "llm_chain" not in values:
            values["llm_chain"] = LLMChain(
                llm=values.get("llm"),
                prompt=PromptTemplate(
                    template=QUERY_CHECKER,
                    input_variables=["query", "dialect"],
                ),
            )

        # Compare order-independently: the prompt above is declared with
        # ["query", "dialect"], and whether the prompt class keeps or
        # reorders that list must not decide if validation passes.
        prompt_variables = values["llm_chain"].prompt.input_variables
        if sorted(prompt_variables) != ["dialect", "query"]:
            raise ValueError(
                "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
            )

        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the LLM to check the query.

        Args:
            query: SQL statement to be checked.
            run_manager: Optional callback manager; not used here.

        Returns:
            The chain's prediction for ``query`` and the database dialect.
        """
        return self.llm_chain.predict(query=query, dialect=self.db.dialect)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async counterpart of ``_run``; awaits the chain's ``apredict``."""
        return await self.llm_chain.apredict(query=query, dialect=self.db.dialect)
|