357 lines
14 KiB
Python
357 lines
14 KiB
Python
import json
|
||
import logging
|
||
|
||
from dashscope import Generation
|
||
from retry import retry
|
||
from urllib3.exceptions import NewConnectionError
|
||
|
||
from app.core.config import *
|
||
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
||
from app.service.chat_robot.script.database import CustomDatabase
|
||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
|
||
GET_LANGUAGE_PREFIX, FASHION_CHAT_BOT_PREFIX_TEMP
|
||
from app.service.search_image_with_text.service import query
|
||
|
||
# Tool description shown to the model for the get_database_table tool.
get_database_table_description = (
    "Input is an empty string, "
    "output is a comma separated list of tables in the database."
)
|
||
|
||
# Tool description shown to the model for the get_table_info tool.
# Adjacent string literals are concatenated with no separator, so each piece
# ends with the space that separates it from the next sentence or list item.
get_table_info_description = (
    "Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. "
    "There are eight tables covering eight fashion categories: female_top, female_pants, female_dress, "
    "female_skirt, female_outwear, male_bottom, male_top, and male_outwear. "
    "Example Input: 'female_outwear, male_top'"
)
|
||
|
||
# Tool description shown to the model for the query_database tool.
# Each literal ends with a space because adjacent literals are joined with no
# separator. The example error names the schema that `db` actually connects to
# (attribute_retrieval_V3), so it matches what the model will see.
query_database_description = (
    "The input of this tool is a detailed and correct SQL select query statement, "
    "and the output is the result of the database, and it can only return up to 4 results. "
    "If the query is not correct, an error message will be returned. "
    "If an error is returned, rewrite the query, check the query, and try again. "
    "If you encounter an issue with Unknown column 'xxxx' in 'field list' or "
    "Table 'attribute_retrieval_V3.xxxx' doesn't exist, "
    "use get_table_info to query the correct table fields. "
    "Example Input: 'SELECT img_name FROM female_skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1' "
    "Example Input 2: 'SELECT img_name FROM female_top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
    "order by rand() LIMIT 2'"
)
|
||
|
||
# Tool description shown to the model for the get_image_from_vector_db tool.
# The sentences are joined with single spaces.
query_vector_db_description = " ".join([
    "Use this tool to find the clothing images that users need.",
    "If the user's input includes clothing types such as blouse, skirt, dress, outerwear, pants, or trousers, please use this tool.",
    "The input for the tool is the string provided by the user.",
])
|
||
|
||
# Tool description shown to the model for the tutorial_tool tool.
# The first literal ends with a space; adjacent literals are joined with no
# separator.
tutorial_description = (
    "Utilize this tool to retrieve specific statements related to user guidance tutorials. "
    "Input is an empty string"
)
|
||
|
||
# Function-calling schemas passed to qwen-max by get_response().
# Each "parameters" value is a single JSON-Schema object: "type": "object",
# its "properties", and the "required" list inside that same object.
# NOTE(review): get_database_table, get_table_info and query_database are
# offered here, but call_with_messages does not dispatch them (their handlers
# are disabled), so a call to them ends the turn unanswered.
tools = [
    # Tool 1 (disabled): "search_from_internet" -- searched the web for the
    # user's input (e.g. "What are the fashion trends for 2025?").
    # Tool 2: list the queryable tables; takes no arguments.
    {
        "type": "function",
        "function": {
            "name": "get_database_table",
            "description": get_database_table_description,
            "parameters": {},
        },
    },
    # Tool 3: schema and sample rows for the named tables.
    {
        "type": "function",
        "function": {
            "name": "get_table_info",
            "description": get_table_info_description,
            "parameters": {
                "type": "object",
                "properties": {
                    "table_names": {
                        # JSON Schema spells a list as "array" (there is no
                        # "list" type); the elements are table-name strings.
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "需要查询表结构的表名",
                    },
                },
                "required": ["table_names"],
            },
        },
    },
    # Tool 4: run a model-generated SQL query.
    {
        "type": "function",
        "function": {
            "name": "query_database",
            "description": query_database_description,
            "parameters": {
                "type": "object",
                "properties": {
                    "sql_string": {
                        "type": "string",
                        "description": "由模型生成的sql语句",
                    },
                },
                "required": ["sql_string"],
            },
        },
    },
    # Tool 5: canned tutorial text; takes no arguments.
    {
        "type": "function",
        "function": {
            "name": "tutorial_tool",
            "description": tutorial_description,
        },
    },
    # Tool 6: clothing image search. "content" is the user's description;
    # "gender" is optional (call_with_messages falls back to 'female').
    {
        "type": "function",
        "function": {
            "name": "get_image_from_vector_db",
            "description": query_vector_db_description,
            "parameters": {
                "type": "object",
                "properties": {
                    "gender": {
                        "type": "string",
                        "description": "性别",
                    },
                    "content": {
                        "type": "string",
                        "description": "用户描述",
                    },
                },
                "required": ["content"],
            },
        },
    },
]
|
||
|
||
# Shared database handle for the get_table_info / query_database tools.
# It connects to the attribute_retrieval_V3 MySQL schema (via pymysql) and
# exposes only the eight fashion-category tables listed below.
# NOTE(review): this is built at import time, so importing the module may
# already touch the database -- confirm CustomDatabase.from_uri behaviour.
# NOTE(review): DB_PASSWORD is interpolated into the URI unescaped; a password
# containing '@', ':' or '/' would break the URI unless config pre-encodes it.
# engine_args is presumably forwarded to SQLAlchemy's create_engine, where
# pool_recycle is in seconds -- TODO confirm.
db = CustomDatabase.from_uri(
    f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
    include_tables=[
        'female_top', 'female_skirt', 'female_pants', 'female_dress',
        'female_outwear', 'male_bottom', 'male_top', 'male_outwear',
    ],
    engine_args=dict(pool_recycle=7200),
)

# Callback handler that call_with_messages feeds with each LLM call's usage
# (on_llm_end) and with the final answer dict (on_chain_end).
qwen = QWenCallbackHandler()
|
||
|
||
|
||
def search_from_internet(message):
    """Ask qwen-turbo to answer a search query, in English and briefly.

    Args:
        message: list of chat-message dicts, forwarded as ``messages``.

    Returns:
        The raw dashscope ``Generation`` response. The caller reads
        ``response.output.text``; no ``result_format`` is requested here.

    NOTE(review): both ``messages`` and ``prompt`` are passed; confirm how the
    dashscope SDK combines them (e.g. whether ``prompt`` is appended as an
    extra user turn).
    """
    # The two instructions are separate sentences, so a space follows the period.
    instruction = 'The output must be in English. Keep the final result under 200 words.'
    response = Generation.call(
        model='qwen-turbo',
        api_key=QWEN_API_KEY,
        messages=message,
        prompt=instruction,
    )
    return response
|
||
|
||
|
||
def get_database_table():
    """Return the names of the eight fashion tables as one ", "-separated string.

    The names are hard-coded; they match the ``include_tables`` list used
    when building ``db``.
    """
    table_names = (
        'female_top', 'female_skirt', 'female_pants', 'female_dress',
        'female_outwear', 'male_bottom', 'male_top', 'male_outwear',
    )
    return ', '.join(table_names)
|
||
|
||
|
||
def get_table_info(table_names):
    """Return the schema and sample rows of ``table_names`` from ``db``.

    Delegates to ``CustomDatabase.get_table_info`` on the module-level ``db``.

    NOTE(review): the tool schema declares ``table_names`` as a list, while
    get_table_info_description describes a comma-separated string. Confirm
    which form CustomDatabase.get_table_info expects.
    """
    schema_text = CustomDatabase.get_table_info(db, table_names)
    return schema_text
|
||
|
||
|
||
def query_database(sql_string):
    """Run ``sql_string`` against the module-level ``db`` and return the result.

    Delegates to ``CustomDatabase.run``.

    NOTE(review): ``sql_string`` is generated by the LLM from user input and is
    executed as-is. Nothing here restricts it to a read-only SELECT, so confirm
    that CustomDatabase.run or the DB account's privileges enforce that.
    """
    query_result = CustomDatabase.run(db, sql_string)
    return query_result
|
||
|
||
|
||
def get_image_from_vector_db(gender, content):
    """Search for clothing images matching a text description.

    Args:
        gender: gender filter, passed through unchanged.
        content: the user's clothing description.

    Returns:
        Whatever ``app.service.search_image_with_text.service.query`` returns.
    """
    matches = query(gender, content)
    return matches
|
||
|
||
|
||
# Retry up to 3 times, 1 s apart, on connection failures. NewConnectionError is
# kept for compatibility. OSError is added because ``requests`` (presumably the
# HTTP client used by the dashscope SDK -- TODO confirm) wraps urllib3's
# NewConnectionError in requests.exceptions.ConnectionError, which is an
# OSError subclass. Without it, the original retry would never fire.
@retry(exceptions=(NewConnectionError, OSError), tries=3, delay=1)
def get_response(messages):
    """Call qwen-max with the tool schemas so the model may request a tool.

    Args:
        messages: the conversation so far (system/user/assistant/tool turns).

    Returns:
        The dashscope ``Generation`` response. ``result_format='message'``
        makes the reply available as ``response.output.choices[0].message``.
    """
    response = Generation.call(
        model='qwen-max',
        api_key=QWEN_API_KEY,
        messages=messages,
        tools=tools,
        result_format='message',  # return the reply as a chat message
        # A real boolean: the parameter is a bool, not the string 'True'.
        enable_search=True,
    )
    return response
|
||
|
||
|
||
# Same retry policy as get_response: NewConnectionError for compatibility, plus
# OSError, which covers requests.exceptions.ConnectionError (an OSError
# subclass) if the SDK surfaces connection failures that way -- TODO confirm.
@retry(exceptions=(NewConnectionError, OSError), tries=3, delay=1)
def get_assistant_response(messages):
    """Call qwen-max for a plain chat reply: no tools and web search off.

    Args:
        messages: the conversation to complete.

    Returns:
        The dashscope ``Generation`` response. The reply is available as
        ``response.output.choices[0].message``.
    """
    response = Generation.call(
        model='qwen-max',
        api_key=QWEN_API_KEY,
        messages=messages,
        result_format='message',  # return the reply as a chat message
        # A real boolean. The string 'false' is non-empty, i.e. truthy, and is
        # not an explicit "off".
        enable_search=False,
    )
    return response
|
||
|
||
|
||
def _parse_tool_arguments(tool_call):
    """Decode a tool call's ``function.arguments`` into a flat dict.

    Returns ``{}`` when the arguments are missing, are not valid JSON, or do
    not decode to an object, so a malformed model reply cannot crash a request.
    Keys nested under an extra ``"parameters"`` object (the shape produced when
    a tool schema nests ``parameters`` twice) are merged into the top level,
    and the nested values win on conflicts.
    """
    raw_arguments = tool_call.get('function', {}).get('arguments') or '{}'
    if isinstance(raw_arguments, dict):
        arguments = raw_arguments
    else:
        try:
            arguments = json.loads(raw_arguments)
        except (TypeError, ValueError):
            logging.warning("Could not decode tool arguments: %r", raw_arguments)
            return {}
    if not isinstance(arguments, dict):
        return {}
    nested = arguments.get('parameters')
    if isinstance(nested, dict):
        arguments = {**arguments, **nested}
    return arguments


def call_with_messages(message):
    """Answer one user message with qwen-max, running at most one tool.

    The model is called with:
    - the fashion-bot system prompt (FASHION_CHAT_BOT_PREFIX_TEMP),
    - the user message,
    - the TOOLS_FUNCTIONS_SUFFIX assistant turn.

    If the reply contains no tool call, it is the answer. Otherwise the first
    requested tool is dispatched and its output becomes the answer. Every
    branch ends the loop, so in practice only one model call is made (the loop
    is capped at 3 rounds).

    Args:
        message: the user's input text.

    Returns:
        A dict with two keys:
        - ``"output"``: the answer text or the tool's result.
        - ``"response_type"``: ``"image"`` for the vector-DB image search,
          otherwise ``"chat"``.
    """
    user_input = message
    messages = [
        {"content": FASHION_CHAT_BOT_PREFIX_TEMP, "role": "system"},
        {"content": message, "role": "user"},
        {"content": TOOLS_FUNCTIONS_SUFFIX, "role": "assistant"},
    ]

    keep_going = True
    count = 1
    # Kept as the answer when the model calls a tool that is not handled below.
    result_content = "I am a fashion AI assistant, how can I help you?"
    response_type = "chat"

    while keep_going and count <= 3:
        first_response = get_response(messages)
        assistant_output = first_response.output.choices[0].message
        QWenCallbackHandler.on_llm_end(qwen, first_response.usage)
        logging.info("\n大模型第 %s 轮输出信息:%s\n", count, first_response)
        messages.append(assistant_output)

        # No tool requested: the model's own reply is the final answer.
        if 'tool_calls' not in assistant_output:
            logging.info("最终答案:%s", assistant_output.content)
            result_content = assistant_output.content
            break

        tool_call = assistant_output.tool_calls[0]
        tool_name = tool_call['function']['name']
        arguments = _parse_tool_arguments(tool_call)

        if tool_name == 'internet_search':
            search_query = arguments.get('query') or user_input
            # NOTE(review): the query is sent as an 'assistant' turn; confirm
            # the API accepts that, or whether it should be a 'user' turn.
            search_response = search_from_internet([{'role': 'assistant', 'content': search_query}])
            tool_info = {"name": "search_from_internet", "role": "tool", "content": search_response}
            keep_going = False
            result_content = search_response.output.text
        elif tool_name == 'tutorial_tool':
            tool_info = {"name": "tutorial_tool", "role": "tool", "content": tutorial_tool()}
            keep_going = False
            result_content = tool_info['content']
        elif tool_name == 'get_image_from_vector_db':
            # TODO: read gender from the conversation history; until then an
            # absent, empty or 'unisex' gender falls back to 'female'.
            gender = arguments.get('gender')
            if not gender or gender == 'unisex':
                gender = 'female'
            # The tool's input is the user's own text, so use it when the
            # model omits "content".
            description = arguments.get('content') or user_input
            tool_info = {"name": "get_image_from_vector_db", "role": "tool",
                         "content": get_image_from_vector_db(gender, description)}
            keep_going = False
            result_content = tool_info['content']
            response_type = "image"
        else:
            # Includes get_database_table / get_table_info / query_database,
            # whose handlers are disabled.
            tool_info = {"name": tool_name, "role": "tool", "content": 'null'}
            logging.info("%s(unknown tools)", tool_name)
            keep_going = False

        logging.info("工具输出信息:%s\n", tool_info['content'])
        messages.append(tool_info)
        count += 1

    final_output = {"output": result_content, "response_type": response_type}
    QWenCallbackHandler.on_chain_end(qwen, final_output)
    return final_output
|
||
|
||
|
||
def tutorial_tool():
    """Return the canned tutorial text (TUTORIAL_TOOL_RETURN) for the tutorial tool."""
    tutorial_text = TUTORIAL_TOOL_RETURN
    return tutorial_text
|
||
|
||
|
||
def get_language(message: str) -> str:
    """Ask qwen-max which language ``message`` is written in.

    The prompt is GET_LANGUAGE_PREFIX as the system message, followed by two
    worked examples as user/assistant turns ("Tree" -> "English",
    "玩具" -> "Chinese"), then ``message`` as the final user turn.

    Returns:
        The model's reply text, presumably a language name such as "English"
        or "Chinese" (not validated here).
    """
    few_shot_examples = (("Tree", "English"), ("玩具", "Chinese"))

    conversation = [{"content": GET_LANGUAGE_PREFIX, "role": "system"}]
    for sample_text, language_name in few_shot_examples:
        conversation += [
            {"content": sample_text, "role": "user"},
            {"content": language_name, "role": "assistant"},
        ]
    conversation.append({"content": message, "role": "user"})

    reply = get_assistant_response(conversation)
    detected_language = reply.output.choices[0].message.content
    logging.info(f"大模型输出信息:{reply}\n判断用户输入的语言为:{detected_language}")
    return detected_language
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: detect the language of the same sample four times.
    sample_text = "森林"
    for _attempt in range(4):
        get_language(sample_text)
|