机器人 -- 开启用户指引

This commit is contained in:
2024-08-16 15:09:08 +08:00
parent 281f812636
commit 9fa9c620a8

View File

@@ -1,5 +1,4 @@
import json import json
from typing import Dict, Any
from dashscope import Generation from dashscope import Generation
from retry import retry from retry import retry
@@ -8,8 +7,7 @@ from urllib3.exceptions import NewConnectionError
from app.core.config import * from app.core.config import *
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database." get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
@@ -22,17 +20,20 @@ get_table_info_description = (
) )
query_database_description = ( query_database_description = (
"The input of this tool is a detailed and correct SQL select query statement, " "The input of this tool is a detailed and correct SQL select query statement, "
"and the output is the result of the database, and it can only return up to 4 results." "and the output is the result of the database, and it can only return up to 4 results."
"If the query is not correct, an error message will be returned." "If the query is not correct, an error message will be returned."
"If an error is returned, rewrite the query, check the query, and try again." "If an error is returned, rewrite the query, check the query, and try again."
"If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist," "If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist,"
"use get_table_info to query the correct table fields." "use get_table_info to query the correct table fields."
"Example Input: 'SELECT img_name FROM female_skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'" "Example Input: 'SELECT img_name FROM female_skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
"Example Input 2: 'SELECT img_name FROM female_top WHERE sleeve_length = 'Long' AND type = 'Blouse' " "Example Input 2: 'SELECT img_name FROM female_top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
"order by rand() LIMIT 2'" "order by rand() LIMIT 2'"
) )
tutorial_description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials."
"Input is an empty string")
tools = [ tools = [
# 工具一 # 工具一
@@ -97,6 +98,23 @@ tools = [
}, },
"required": ["sql_string"] "required": ["sql_string"]
} }
},
# 工具四
{
"type": "function",
"function": {
"name": "tutorial_tool",
"description": tutorial_description,
"parameters": {
"type": "object",
"properties": {
"sql_string": {
"type": "string",
"description": "由模型生成的sql语句"
}
}
},
}
} }
] ]
@@ -106,6 +124,7 @@ db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_H
engine_args={"pool_recycle": 7200}) engine_args={"pool_recycle": 7200})
qwen = QWenCallbackHandler() qwen = QWenCallbackHandler()
def search_from_internet(message): def search_from_internet(message):
response = Generation.call( response = Generation.call(
model='qwen-turbo', model='qwen-turbo',
@@ -118,15 +137,19 @@ def search_from_internet(message):
) )
return response return response
def get_database_table(): def get_database_table():
return 'female_top, female_skirt, female_pants, female_dress, female_outwear, male_bottom, male_top, male_outwear' return 'female_top, female_skirt, female_pants, female_dress, female_outwear, male_bottom, male_top, male_outwear'
def get_table_info(table_names): def get_table_info(table_names):
return CustomDatabase.get_table_info(db, table_names) return CustomDatabase.get_table_info(db, table_names)
def query_database(sql_string): def query_database(sql_string):
return CustomDatabase.run(db, sql_string) return CustomDatabase.run(db, sql_string)
@retry(exceptions=NewConnectionError, tries=3, delay=1) @retry(exceptions=NewConnectionError, tries=3, delay=1)
def get_response(messages): def get_response(messages):
response = Generation.call( response = Generation.call(
@@ -206,8 +229,12 @@ def call_with_messages(message):
sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string'] sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string']
tool_info['content'] = query_database(sql_string) tool_info['content'] = query_database(sql_string)
flag = False flag = False
result_content = tool_info['content'] result_content = tool_info['content']
response_type = "image" response_type = "image"
elif assistant_output.tool_calls[0]['function']['name'] == 'tutorial_tool':
tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()}
flag = False
result_content = tool_info['content']
print(f"工具输出信息:{tool_info['content']}\n") print(f"工具输出信息:{tool_info['content']}\n")
messages.append(tool_info) messages.append(tool_info)
@@ -217,8 +244,6 @@ def call_with_messages(message):
final_output["response_type"] = response_type final_output["response_type"] = response_type
QWenCallbackHandler.on_chain_end(qwen, final_output) QWenCallbackHandler.on_chain_end(qwen, final_output)
# 模型的第二轮调用,对工具的输出进行总结 # 模型的第二轮调用,对工具的输出进行总结
# if flag : # if flag :
# second_response = get_response(messages) # second_response = get_response(messages)
@@ -229,9 +254,8 @@ def call_with_messages(message):
return final_output return final_output
def tutorial_tool():
return TUTORIAL_TOOL_RETURN
if __name__ == '__main__': if __name__ == '__main__':
call_with_messages() call_with_messages()