机器人 -- 开启用户指引
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
from dashscope import Generation
|
from dashscope import Generation
|
||||||
from retry import retry
|
from retry import retry
|
||||||
@@ -8,8 +7,7 @@ from urllib3.exceptions import NewConnectionError
|
|||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
||||||
from app.service.chat_robot.script.database import CustomDatabase
|
from app.service.chat_robot.script.database import CustomDatabase
|
||||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
|
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN
|
||||||
|
|
||||||
|
|
||||||
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
|
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
|
||||||
|
|
||||||
@@ -22,17 +20,20 @@ get_table_info_description = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
query_database_description = (
|
query_database_description = (
|
||||||
"The input of this tool is a detailed and correct SQL select query statement, "
|
"The input of this tool is a detailed and correct SQL select query statement, "
|
||||||
"and the output is the result of the database, and it can only return up to 4 results."
|
"and the output is the result of the database, and it can only return up to 4 results."
|
||||||
"If the query is not correct, an error message will be returned."
|
"If the query is not correct, an error message will be returned."
|
||||||
"If an error is returned, rewrite the query, check the query, and try again."
|
"If an error is returned, rewrite the query, check the query, and try again."
|
||||||
"If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist,"
|
"If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist,"
|
||||||
"use get_table_info to query the correct table fields."
|
"use get_table_info to query the correct table fields."
|
||||||
|
|
||||||
"Example Input: 'SELECT img_name FROM female_skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
|
"Example Input: 'SELECT img_name FROM female_skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
|
||||||
"Example Input 2: 'SELECT img_name FROM female_top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
|
"Example Input 2: 'SELECT img_name FROM female_top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
|
||||||
"order by rand() LIMIT 2'"
|
"order by rand() LIMIT 2'"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tutorial_description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials."
|
||||||
|
"Input is an empty string")
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
# 工具一
|
# 工具一
|
||||||
@@ -97,6 +98,23 @@ tools = [
|
|||||||
},
|
},
|
||||||
"required": ["sql_string"]
|
"required": ["sql_string"]
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
# 工具四
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "tutorial_tool",
|
||||||
|
"description": tutorial_description,
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"sql_string": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "由模型生成的sql语句"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -106,6 +124,7 @@ db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_H
|
|||||||
engine_args={"pool_recycle": 7200})
|
engine_args={"pool_recycle": 7200})
|
||||||
qwen = QWenCallbackHandler()
|
qwen = QWenCallbackHandler()
|
||||||
|
|
||||||
|
|
||||||
def search_from_internet(message):
|
def search_from_internet(message):
|
||||||
response = Generation.call(
|
response = Generation.call(
|
||||||
model='qwen-turbo',
|
model='qwen-turbo',
|
||||||
@@ -118,15 +137,19 @@ def search_from_internet(message):
|
|||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def get_database_table():
|
def get_database_table():
|
||||||
return 'female_top, female_skirt, female_pants, female_dress, female_outwear, male_bottom, male_top, male_outwear'
|
return 'female_top, female_skirt, female_pants, female_dress, female_outwear, male_bottom, male_top, male_outwear'
|
||||||
|
|
||||||
|
|
||||||
def get_table_info(table_names):
|
def get_table_info(table_names):
|
||||||
return CustomDatabase.get_table_info(db, table_names)
|
return CustomDatabase.get_table_info(db, table_names)
|
||||||
|
|
||||||
|
|
||||||
def query_database(sql_string):
|
def query_database(sql_string):
|
||||||
return CustomDatabase.run(db, sql_string)
|
return CustomDatabase.run(db, sql_string)
|
||||||
|
|
||||||
|
|
||||||
@retry(exceptions=NewConnectionError, tries=3, delay=1)
|
@retry(exceptions=NewConnectionError, tries=3, delay=1)
|
||||||
def get_response(messages):
|
def get_response(messages):
|
||||||
response = Generation.call(
|
response = Generation.call(
|
||||||
@@ -206,8 +229,12 @@ def call_with_messages(message):
|
|||||||
sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string']
|
sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string']
|
||||||
tool_info['content'] = query_database(sql_string)
|
tool_info['content'] = query_database(sql_string)
|
||||||
flag = False
|
flag = False
|
||||||
result_content = tool_info['content']
|
result_content = tool_info['content']
|
||||||
response_type = "image"
|
response_type = "image"
|
||||||
|
elif assistant_output.tool_calls[0]['function']['name'] == 'tutorial_tool':
|
||||||
|
tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()}
|
||||||
|
flag = False
|
||||||
|
result_content = tool_info['content']
|
||||||
|
|
||||||
print(f"工具输出信息:{tool_info['content']}\n")
|
print(f"工具输出信息:{tool_info['content']}\n")
|
||||||
messages.append(tool_info)
|
messages.append(tool_info)
|
||||||
@@ -217,8 +244,6 @@ def call_with_messages(message):
|
|||||||
final_output["response_type"] = response_type
|
final_output["response_type"] = response_type
|
||||||
QWenCallbackHandler.on_chain_end(qwen, final_output)
|
QWenCallbackHandler.on_chain_end(qwen, final_output)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 模型的第二轮调用,对工具的输出进行总结
|
# 模型的第二轮调用,对工具的输出进行总结
|
||||||
# if flag :
|
# if flag :
|
||||||
# second_response = get_response(messages)
|
# second_response = get_response(messages)
|
||||||
@@ -229,9 +254,8 @@ def call_with_messages(message):
|
|||||||
return final_output
|
return final_output
|
||||||
|
|
||||||
|
|
||||||
|
def tutorial_tool():
|
||||||
|
return TUTORIAL_TOOL_RETURN
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
call_with_messages()
|
call_with_messages()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user