从向量数据库中检索图片并集成到chat-robot
This commit is contained in:
@@ -8,6 +8,7 @@ from app.core.config import *
|
||||
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
||||
from app.service.chat_robot.script.database import CustomDatabase
|
||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN
|
||||
from app.service.search_image_with_text.service import query
|
||||
|
||||
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
|
||||
|
||||
@@ -32,6 +33,12 @@ query_database_description = (
|
||||
"order by rand() LIMIT 2'"
|
||||
)
|
||||
|
||||
query_vector_db_description = (
|
||||
"Use this tool to find the clothing images that users need. "
|
||||
"If the user's input includes clothing types such as blouse, skirt, dress, outerwear, pants, or trousers, please use this tool. "
|
||||
"The input for the tool is the string provided by the user."
|
||||
)
|
||||
|
||||
tutorial_description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials."
|
||||
"Input is an empty string")
|
||||
|
||||
@@ -105,15 +112,37 @@ tools = [
|
||||
"function": {
|
||||
"name": "tutorial_tool",
|
||||
"description": tutorial_description,
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "sql_string": {
|
||||
# "type": "string",
|
||||
# "description": "由模型生成的sql语句"
|
||||
# }
|
||||
# }
|
||||
# },
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_image_from_vector_db",
|
||||
"description": query_vector_db_description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sql_string": {
|
||||
"type": "string",
|
||||
"description": "由模型生成的sql语句"
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"gender": {
|
||||
"type": "string",
|
||||
"description": "性别"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "用户描述"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -150,6 +179,10 @@ def query_database(sql_string):
|
||||
return CustomDatabase.run(db, sql_string)
|
||||
|
||||
|
||||
def get_image_from_vector_db(gender, content):
|
||||
return query(gender, content)
|
||||
|
||||
|
||||
@retry(exceptions=NewConnectionError, tries=3, delay=1)
|
||||
def get_response(messages):
|
||||
response = Generation.call(
|
||||
@@ -164,7 +197,8 @@ def get_response(messages):
|
||||
return response
|
||||
|
||||
|
||||
def call_with_messages(message):
|
||||
def call_with_messages(message, gender):
|
||||
user_input = message
|
||||
print('\n')
|
||||
# messages = [
|
||||
# {
|
||||
@@ -235,6 +269,12 @@ def call_with_messages(message):
|
||||
tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()}
|
||||
flag = False
|
||||
result_content = tool_info['content']
|
||||
elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db':
|
||||
tool_info = {"name": "get_image_from_vector_db", "role": "tool",
|
||||
'content': get_image_from_vector_db(gender, user_input)}
|
||||
flag = False
|
||||
result_content = tool_info['content']
|
||||
response_type = "image"
|
||||
|
||||
print(f"工具输出信息:{tool_info['content']}\n")
|
||||
messages.append(tool_info)
|
||||
@@ -257,5 +297,6 @@ def call_with_messages(message):
|
||||
def tutorial_tool():
|
||||
return TUTORIAL_TOOL_RETURN
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
call_with_messages()
|
||||
|
||||
Reference in New Issue
Block a user