From 9fa9c620a87d7b5c23de9a2cce4fee6779fc7863 Mon Sep 17 00:00:00 2001 From: xupei Date: Fri, 16 Aug 2024 15:09:08 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9C=BA=E5=99=A8=E4=BA=BA=20--=20=E5=BC=80?= =?UTF-8?q?=E5=90=AF=E7=94=A8=E6=88=B7=E6=8C=87=E5=BC=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat_robot/script/service/CallQWen.py | 64 +++++++++++++------ 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/app/service/chat_robot/script/service/CallQWen.py b/app/service/chat_robot/script/service/CallQWen.py index f8e6bd5..d2e2c06 100644 --- a/app/service/chat_robot/script/service/CallQWen.py +++ b/app/service/chat_robot/script/service/CallQWen.py @@ -1,5 +1,4 @@ import json -from typing import Dict, Any from dashscope import Generation from retry import retry @@ -8,8 +7,7 @@ from urllib3.exceptions import NewConnectionError from app.core.config import * from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler from app.service.chat_robot.script.database import CustomDatabase -from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX - +from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database." @@ -22,17 +20,20 @@ get_table_info_description = ( ) query_database_description = ( - "The input of this tool is a detailed and correct SQL select query statement, " - "and the output is the result of the database, and it can only return up to 4 results." - "If the query is not correct, an error message will be returned." - "If an error is returned, rewrite the query, check the query, and try again." - "If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist," - "use get_table_info to query the correct table fields." + "The input of this tool is a detailed and correct SQL select query statement, " + "and the output is the result of the database, and it can only return up to 4 results." + "If the query is not correct, an error message will be returned." + "If an error is returned, rewrite the query, check the query, and try again." + "If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist," + "use get_table_info to query the correct table fields." - "Example Input: 'SELECT img_name FROM female_skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'" - "Example Input 2: 'SELECT img_name FROM female_top WHERE sleeve_length = 'Long' AND type = 'Blouse' " - "order by rand() LIMIT 2'" - ) + "Example Input: 'SELECT img_name FROM female_skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'" + "Example Input 2: 'SELECT img_name FROM female_top WHERE sleeve_length = 'Long' AND type = 'Blouse' " + "order by rand() LIMIT 2'" +) + +tutorial_description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials." + "Input is an empty string") tools = [ # 工具一 @@ -97,6 +98,23 @@ tools = [ }, "required": ["sql_string"] } + }, + # 工具四 + { + "type": "function", + "function": { + "name": "tutorial_tool", + "description": tutorial_description, + "parameters": { + "type": "object", + "properties": { + "sql_string": { + "type": "string", + "description": "由模型生成的sql语句" + } + } + }, + } } ] @@ -106,6 +124,7 @@ db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_H engine_args={"pool_recycle": 7200}) qwen = QWenCallbackHandler() + def search_from_internet(message): response = Generation.call( model='qwen-turbo', @@ -118,15 +137,19 @@ def search_from_internet(message): ) return response + def get_database_table(): return 'female_top, female_skirt, female_pants, female_dress, female_outwear, male_bottom, male_top, male_outwear' + def get_table_info(table_names): return CustomDatabase.get_table_info(db, table_names) + def query_database(sql_string): return CustomDatabase.run(db, sql_string) + @retry(exceptions=NewConnectionError, tries=3, delay=1) def get_response(messages): response = Generation.call( @@ -206,8 +229,12 @@ def call_with_messages(message): sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string'] tool_info['content'] = query_database(sql_string) flag = False - result_content = tool_info['content'] + result_content = tool_info['content'] response_type = "image" + elif assistant_output.tool_calls[0]['function']['name'] == 'tutorial_tool': + tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()} + flag = False + result_content = tool_info['content'] print(f"工具输出信息:{tool_info['content']}\n") messages.append(tool_info) @@ -217,8 +244,6 @@ def call_with_messages(message): final_output["response_type"] = response_type QWenCallbackHandler.on_chain_end(qwen, final_output) - - # 模型的第二轮调用,对工具的输出进行总结 # if flag : # second_response = get_response(messages) @@ -229,9 +254,8 @@ def call_with_messages(message): return final_output +def tutorial_tool(): + return TUTORIAL_TOOL_RETURN + if __name__ == '__main__': call_with_messages() - - - -