feat chat robot 接口迁移

This commit is contained in:
zhouchengrong
2024-05-29 11:12:59 +08:00
parent a9dcd444c8
commit 13fec64125
23 changed files with 1139 additions and 1 deletions

View File

@@ -0,0 +1,183 @@
# flake8: noqa
"""Tools for interacting with a SQL database."""
from typing import Any, Dict, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
# from langchain.sql_database import SQLDatabase
from langchain.utilities import SQLDatabase
from langchain.tools.base import BaseTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER
class BaseSQLDatabaseTool(BaseModel):
"""Base tools for interacting with a SQL database."""
db: SQLDatabase = Field(exclude=True)
param_description: str = ""
# Override BaseTool.Config to appease mypy
# See https://github.com/pydantic/pydantic/issues/4173
class Config(BaseTool.Config):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
extra = Extra.forbid
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for querying a SQL database."""
name = "sql_db_query"
# description = """
# Before use this tool, another tool named sql_db_schema must be used first to find the schema of interested tables.
# This tool is designed exclusively for generating SELECT queries to retrieve clothing's img_name randomly from a MySQL database.
# You should always use order by rand() to randomly select data.
# If the query is not correct, an error message will be returned.
# If an error is returned, rewrite the query, check the query, and try again.
# Always limit your query to at most 4 results.
# Never query for all the columns from a specific table, only ask for the relevant columns given the question.
# You MUST double check your query before executing it.
# DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
# """
description: str = (
"The input of this tool is a detailed and correct SQL select query statement, "
"and the output is the result of the database, and it can only return up to 4 results."
"If the query is not correct, an error message will be returned."
"If an error is returned, rewrite the query, check the query, and try again."
"If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist,"
"use sql_db_schema to query the correct table fields."
"Example Input: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() "
"LIMIT 1'"
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
"order by rand() LIMIT 2'"
)
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Execute the query, return the results or an error message."""
result = self.db.run_no_throw(query)
return result
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("QuerySqlDbTool does not support async")
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting metadata about a SQL database."""
name = "sql_db_schema"
# description = """
# The database contains information of lots of fashion items, such as item name, their fashion attributes.
# There are five tables covering five fashion categories: top, pants, dress, skirt, and outwear.
# Find the most relevant tables with the query, and output the schema of these tables.
# """
description: str = (
"Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables."
"There are eight tables covering eight fashion categories: female_top, female_pants, female_dress,"
"female_skirt, female_outwear, male_bottom, male_top, and male_outwear."
"Example Input: 'female_outwear, male_top'"
)
def _run(
self,
table_names: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get the schema for tables in a comma-separated list."""
return self.db.get_table_info_no_throw(table_names.split(", "))
async def _arun(
self,
table_name: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("SchemaSqlDbTool does not support async")
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting tables names."""
name = "sql_db_list_tables"
description = "Input is an empty string, output is a comma separated list of tables in the database."
def _run(
self,
tool_input: str = "",
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get the schema for a specific table."""
return ", ".join(self.db.get_usable_table_names())
async def _arun(
self,
tool_input: str = "",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("ListTablesSqlDbTool does not support async")
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
"""Use an LLM to check if a query is correct.
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
template: str = QUERY_CHECKER
llm: BaseLanguageModel
llm_chain: LLMChain = Field(init=False)
name = "sql_db_query_checker"
description = (
"Use this tools to double check if your query is correct before executing it."
"Always use this tools before executing a query with sql_db_query!"
)
@root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if "llm_chain" not in values:
values["llm_chain"] = LLMChain(
llm=values.get("llm"),
prompt=PromptTemplate(
template=QUERY_CHECKER,
input_variables=["query", "dialect"]
),
)
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
# if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
raise ValueError(
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
)
return values
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the LLM to check the query."""
return self.llm_chain.predict(query=query, dialect=self.db.dialect)
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
return await self.llm_chain.apredict(query=query, dialect=self.db.dialect)