feat chat robot 接口迁移
This commit is contained in:
10
app/service/chat_robot/script/tools/__init__.py
Normal file
10
app/service/chat_robot/script/tools/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .sql_tools import (
|
||||
QuerySQLDataBaseTool,
|
||||
InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
QuerySQLCheckerTool
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"QuerySQLCheckerTool", "InfoSQLDatabaseTool", "ListSQLDatabaseTool", "QuerySQLDataBaseTool"
|
||||
]
|
||||
183
app/service/chat_robot/script/tools/sql_tools.py
Normal file
183
app/service/chat_robot/script/tools/sql_tools.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# flake8: noqa
|
||||
"""Tools for interacting with a SQL database."""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
# from langchain.sql_database import SQLDatabase
|
||||
from langchain.utilities import SQLDatabase
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
|
||||
class BaseSQLDatabaseTool(BaseModel):
|
||||
"""Base tools for interacting with a SQL database."""
|
||||
|
||||
db: SQLDatabase = Field(exclude=True)
|
||||
param_description: str = ""
|
||||
|
||||
# Override BaseTool.Config to appease mypy
|
||||
# See https://github.com/pydantic/pydantic/issues/4173
|
||||
class Config(BaseTool.Config):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.forbid
|
||||
|
||||
|
||||
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Tool for querying a SQL database."""
|
||||
|
||||
name = "sql_db_query"
|
||||
# description = """
|
||||
# Before use this tool, another tool named sql_db_schema must be used first to find the schema of interested tables.
|
||||
# This tool is designed exclusively for generating SELECT queries to retrieve clothing's img_name randomly from a MySQL database.
|
||||
# You should always use ‘order by rand()’ to randomly select data.
|
||||
# If the query is not correct, an error message will be returned.
|
||||
# If an error is returned, rewrite the query, check the query, and try again.
|
||||
# Always limit your query to at most 4 results.
|
||||
# Never query for all the columns from a specific table, only ask for the relevant columns given the question.
|
||||
# You MUST double check your query before executing it.
|
||||
# DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
|
||||
# """
|
||||
|
||||
description: str = (
|
||||
"The input of this tool is a detailed and correct SQL select query statement, "
|
||||
"and the output is the result of the database, and it can only return up to 4 results."
|
||||
"If the query is not correct, an error message will be returned."
|
||||
"If an error is returned, rewrite the query, check the query, and try again."
|
||||
"If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist,"
|
||||
"use sql_db_schema to query the correct table fields."
|
||||
|
||||
"Example Input: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() "
|
||||
"LIMIT 1'"
|
||||
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
|
||||
"order by rand() LIMIT 2'"
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Execute the query, return the results or an error message."""
|
||||
result = self.db.run_no_throw(query)
|
||||
return result
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
raise NotImplementedError("QuerySqlDbTool does not support async")
|
||||
|
||||
|
||||
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Tool for getting metadata about a SQL database."""
|
||||
|
||||
name = "sql_db_schema"
|
||||
# description = """
|
||||
# The database contains information of lots of fashion items, such as item name, their fashion attributes.
|
||||
# There are five tables covering five fashion categories: top, pants, dress, skirt, and outwear.
|
||||
# Find the most relevant tables with the query, and output the schema of these tables.
|
||||
# """
|
||||
|
||||
description: str = (
|
||||
"Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables."
|
||||
"There are eight tables covering eight fashion categories: female_top, female_pants, female_dress,"
|
||||
"female_skirt, female_outwear, male_bottom, male_top, and male_outwear."
|
||||
|
||||
"Example Input: 'female_outwear, male_top'"
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
table_names: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Get the schema for tables in a comma-separated list."""
|
||||
return self.db.get_table_info_no_throw(table_names.split(", "))
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
table_name: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
raise NotImplementedError("SchemaSqlDbTool does not support async")
|
||||
|
||||
|
||||
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Tool for getting tables names."""
|
||||
|
||||
name = "sql_db_list_tables"
|
||||
description = "Input is an empty string, output is a comma separated list of tables in the database."
|
||||
|
||||
def _run(
|
||||
self,
|
||||
tool_input: str = "",
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Get the schema for a specific table."""
|
||||
return ", ".join(self.db.get_usable_table_names())
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
tool_input: str = "",
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
raise NotImplementedError("ListTablesSqlDbTool does not support async")
|
||||
|
||||
|
||||
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Use an LLM to check if a query is correct.
|
||||
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
||||
|
||||
template: str = QUERY_CHECKER
|
||||
llm: BaseLanguageModel
|
||||
llm_chain: LLMChain = Field(init=False)
|
||||
name = "sql_db_query_checker"
|
||||
description = (
|
||||
"Use this tools to double check if your query is correct before executing it."
|
||||
"Always use this tools before executing a query with sql_db_query!"
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "llm_chain" not in values:
|
||||
values["llm_chain"] = LLMChain(
|
||||
llm=values.get("llm"),
|
||||
prompt=PromptTemplate(
|
||||
template=QUERY_CHECKER,
|
||||
input_variables=["query", "dialect"]
|
||||
),
|
||||
)
|
||||
|
||||
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
|
||||
# if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
|
||||
raise ValueError(
|
||||
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the LLM to check the query."""
|
||||
return self.llm_chain.predict(query=query, dialect=self.db.dialect)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
return await self.llm_chain.apredict(query=query, dialect=self.db.dialect)
|
||||
19
app/service/chat_robot/script/tools/tutorial_tool.py
Normal file
19
app/service/chat_robot/script/tools/tutorial_tool.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN
|
||||
|
||||
|
||||
# 处理系统引导教程相关的输入
|
||||
class CustomTutorialTool(BaseTool):
|
||||
name = "tutorial_tool"
|
||||
|
||||
description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials."
|
||||
"Input is an empty string")
|
||||
|
||||
def _run(self, tool_input, **kwargs: Any) -> str:
|
||||
return TUTORIAL_TOOL_RETURN
|
||||
|
||||
async def _arun(self, tool_input, **kwargs: Any) -> str:
|
||||
raise NotImplementedError("CustomTutorialTool does not support async")
|
||||
Reference in New Issue
Block a user