机器人 -- 开启用户指引

This commit is contained in:
2024-08-16 15:09:08 +08:00
parent 281f812636
commit 9fa9c620a8

View File

@@ -1,5 +1,4 @@
import json import json
from typing import Dict, Any
from dashscope import Generation from dashscope import Generation
from retry import retry from retry import retry
@@ -8,8 +7,7 @@ from urllib3.exceptions import NewConnectionError
from app.core.config import * from app.core.config import *
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database." get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
@@ -32,7 +30,10 @@ query_database_description = (
"Example Input: 'SELECT img_name FROM female_skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'" "Example Input: 'SELECT img_name FROM female_skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
"Example Input 2: 'SELECT img_name FROM female_top WHERE sleeve_length = 'Long' AND type = 'Blouse' " "Example Input 2: 'SELECT img_name FROM female_top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
"order by rand() LIMIT 2'" "order by rand() LIMIT 2'"
) )
tutorial_description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials."
"Input is an empty string")
tools = [ tools = [
# 工具一 # 工具一
@@ -97,6 +98,23 @@ tools = [
}, },
"required": ["sql_string"] "required": ["sql_string"]
} }
},
# 工具四
{
"type": "function",
"function": {
"name": "tutorial_tool",
"description": tutorial_description,
"parameters": {
"type": "object",
"properties": {
"sql_string": {
"type": "string",
"description": "由模型生成的sql语句"
}
}
},
}
} }
] ]
@@ -106,6 +124,7 @@ db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_H
engine_args={"pool_recycle": 7200}) engine_args={"pool_recycle": 7200})
qwen = QWenCallbackHandler() qwen = QWenCallbackHandler()
def search_from_internet(message): def search_from_internet(message):
response = Generation.call( response = Generation.call(
model='qwen-turbo', model='qwen-turbo',
@@ -118,15 +137,19 @@ def search_from_internet(message):
) )
return response return response
def get_database_table(): def get_database_table():
return 'female_top, female_skirt, female_pants, female_dress, female_outwear, male_bottom, male_top, male_outwear' return 'female_top, female_skirt, female_pants, female_dress, female_outwear, male_bottom, male_top, male_outwear'
def get_table_info(table_names): def get_table_info(table_names):
return CustomDatabase.get_table_info(db, table_names) return CustomDatabase.get_table_info(db, table_names)
def query_database(sql_string): def query_database(sql_string):
return CustomDatabase.run(db, sql_string) return CustomDatabase.run(db, sql_string)
@retry(exceptions=NewConnectionError, tries=3, delay=1) @retry(exceptions=NewConnectionError, tries=3, delay=1)
def get_response(messages): def get_response(messages):
response = Generation.call( response = Generation.call(
@@ -208,6 +231,10 @@ def call_with_messages(message):
flag = False flag = False
result_content = tool_info['content'] result_content = tool_info['content']
response_type = "image" response_type = "image"
elif assistant_output.tool_calls[0]['function']['name'] == 'tutorial_tool':
tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()}
flag = False
result_content = tool_info['content']
print(f"工具输出信息:{tool_info['content']}\n") print(f"工具输出信息:{tool_info['content']}\n")
messages.append(tool_info) messages.append(tool_info)
@@ -217,8 +244,6 @@ def call_with_messages(message):
final_output["response_type"] = response_type final_output["response_type"] = response_type
QWenCallbackHandler.on_chain_end(qwen, final_output) QWenCallbackHandler.on_chain_end(qwen, final_output)
# 模型的第二轮调用,对工具的输出进行总结 # 模型的第二轮调用,对工具的输出进行总结
# if flag : # if flag :
# second_response = get_response(messages) # second_response = get_response(messages)
@@ -229,9 +254,8 @@ def call_with_messages(message):
return final_output return final_output
def tutorial_tool():
return TUTORIAL_TOOL_RETURN
if __name__ == '__main__': if __name__ == '__main__':
call_with_messages() call_with_messages()