从向量数据库中检索图片并集成到chat-robot

This commit is contained in:
2024-10-29 16:50:46 +08:00
parent b9d2b510a3
commit aca90159d3
7 changed files with 217 additions and 25 deletions

View File

@@ -0,0 +1,36 @@
import json
import logging
from http.client import HTTPException
from fastapi import APIRouter
from app.schemas.query_image import QueryImageModel
from app.schemas.response_template import ResponseModel
from app.service.search_image_with_text.service import query
router = APIRouter()
logger = logging.getLogger()
@router.post("/query_image")
def query_image(request_data: QueryImageModel):
"""
对话机器人
创建一个具有以下参数的请求体:
- **gender**: 性别
- **content**: 用户输入的内容
示例参数:
{
"gender": "male",
"content": "give me a long sleeve blouse",
}
"""
try:
logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict())}")
data = query(request_data.gender, request_data.content)
logger.info(f"query_image response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"query_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.api import api_attribute_retrieve from app.api import api_attribute_retrieve, api_query_image
from app.api import api_brighten from app.api import api_brighten
from app.api import api_chat_robot from app.api import api_chat_robot
from app.api import api_design from app.api import api_design
@@ -23,3 +23,4 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'],
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api") router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api") router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api") router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class QueryImageModel(BaseModel):
gender: str
content: str

View File

@@ -100,7 +100,7 @@ def chat(post_data):
# session_key=f"buffer:{user_id}:{session_id}", # session_key=f"buffer:{user_id}:{session_id}",
# ) # )
final_outputs = CallQWen.call_with_messages(input_message) final_outputs = CallQWen.call_with_messages(input_message, gender)
# api_response = { # api_response = {
# 'user_id': user_id, # 'user_id': user_id,
# 'session_id': session_id, # 'session_id': session_id,

View File

@@ -1,16 +1,31 @@
# FASHION_CHAT_BOT_PREFIX = """
# You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can.
# The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database.
# Remember your answer should be very precise and the final output answer should not exceed 20 words.
#
# You may encounter the following types of questions:
# 1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database.
# Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables.
# Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results.
# Never query for all the columns from a specific table, only ask for the relevant columns given the question.
# You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
# DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
# If the question does not seem related to the database, just return "I don't know" as the answer.
#
# 2) If the query related to current events, you should use internet_search to seek help from the internet.
#
# 3) If the query is just casual conversation, engage in the conversation as a fashion designer assistant.
#
# Be careful to use the tools, since you are actually a chat bot. Tools can only be used when essential.
# """
FASHION_CHAT_BOT_PREFIX = """ FASHION_CHAT_BOT_PREFIX = """
You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can. You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can.
The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database. The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database.
Remember your answer should be very precise and the final output answer should not exceed 20 words. Remember your answer should be very precise and the final output answer should not exceed 20 words.
You may encounter the following types of questions: You may encounter the following types of questions:
1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database. 1) If you need to query information related to clothing retrieval, please use the get_image_from_vector_db tool.
Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables.
Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
2) If the query related to current events, you should use internet_search to seek help from the internet. 2) If the query related to current events, you should use internet_search to seek help from the internet.
@@ -37,15 +52,19 @@ ANSWER_FORMAT_SUFFIX = """
My final answer are limited to 20 words and be as much precise as possible. My final answer are limited to 20 words and be as much precise as possible.
""" """
# TOOLS_FUNCTIONS_SUFFIX = (
# "If the input involves clothing queries,"
# "I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."
# "All SQL statements must use 'ORDER BY RAND()', for example:"
# "Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
# "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'"
# "If the input does not involve clothing queries, "
# "I should engage in conversation as an assistant or search from internet with internet_search tool."
# "If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'"
# "Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool "
# )
TOOLS_FUNCTIONS_SUFFIX = ( TOOLS_FUNCTIONS_SUFFIX = (
"If the input involves clothing queries," "If the input involves clothing queries,please use the get_image_from_vector_db tool."
"I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."
"All SQL statements must use 'ORDER BY RAND()', for example:"
"Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'"
"If the input does not involve clothing queries, "
"I should engage in conversation as an assistant or search from internet with internet_search tool."
"If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'"
"Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool " "Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool "
) )

View File

@@ -8,6 +8,7 @@ from app.core.config import *
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN
from app.service.search_image_with_text.service import query
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database." get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
@@ -32,6 +33,12 @@ query_database_description = (
"order by rand() LIMIT 2'" "order by rand() LIMIT 2'"
) )
query_vector_db_description = (
"Use this tool to find the clothing images that users need. "
"If the user's input includes clothing types such as blouse, skirt, dress, outerwear, pants, or trousers, please use this tool. "
"The input for the tool is the string provided by the user."
)
tutorial_description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials." tutorial_description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials."
"Input is an empty string") "Input is an empty string")
@@ -105,15 +112,37 @@ tools = [
"function": { "function": {
"name": "tutorial_tool", "name": "tutorial_tool",
"description": tutorial_description, "description": tutorial_description,
# "parameters": {
# "type": "object",
# "properties": {
# "sql_string": {
# "type": "string",
# "description": "由模型生成的sql语句"
# }
# }
# },
}
},
{
"type": "function",
"function": {
"name": "get_image_from_vector_db",
"description": query_vector_db_description,
"parameters": { "parameters": {
"type": "object", "parameters": {
"properties": { "type": "object",
"sql_string": { "properties": {
"type": "string", "gender": {
"description": "由模型生成的sql语句" "type": "string",
"description": "性别"
},
"content": {
"type": "string",
"description": "用户描述"
}
} }
} },
}, }
} }
} }
] ]
@@ -150,6 +179,10 @@ def query_database(sql_string):
return CustomDatabase.run(db, sql_string) return CustomDatabase.run(db, sql_string)
def get_image_from_vector_db(gender, content):
return query(gender, content)
@retry(exceptions=NewConnectionError, tries=3, delay=1) @retry(exceptions=NewConnectionError, tries=3, delay=1)
def get_response(messages): def get_response(messages):
response = Generation.call( response = Generation.call(
@@ -164,7 +197,8 @@ def get_response(messages):
return response return response
def call_with_messages(message): def call_with_messages(message, gender):
user_input = message
print('\n') print('\n')
# messages = [ # messages = [
# { # {
@@ -235,6 +269,12 @@ def call_with_messages(message):
tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()} tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()}
flag = False flag = False
result_content = tool_info['content'] result_content = tool_info['content']
elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db':
tool_info = {"name": "get_image_from_vector_db", "role": "tool",
'content': get_image_from_vector_db(gender, user_input)}
flag = False
result_content = tool_info['content']
response_type = "image"
print(f"工具输出信息:{tool_info['content']}\n") print(f"工具输出信息:{tool_info['content']}\n")
messages.append(tool_info) messages.append(tool_info)
@@ -257,5 +297,6 @@ def call_with_messages(message):
def tutorial_tool(): def tutorial_tool():
return TUTORIAL_TOOL_RETURN return TUTORIAL_TOOL_RETURN
if __name__ == '__main__': if __name__ == '__main__':
call_with_messages() call_with_messages()

View File

@@ -0,0 +1,89 @@
import chromadb
import hashlib
import pandas as pd
from chromadb.config import Settings
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
from tqdm import tqdm
# 读取 csv 文件
csv_file_path = r'D:/Files/csv/output/output.csv'
image_path = r'D:/images-clean'
df = pd.read_csv(csv_file_path, encoding='Windows-1252')
# 创建 Chroma 客户端
client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db"))
# client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db"))
# client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db"))
# 创建集合
embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large")
def create_collection():
collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn)
# 存储数据,包括自定义属性
images_description = []
images_metadata = []
ids = []
batch_size = 41666 # 最大批量大小
for index, row in tqdm(df.iterrows()):
# 将图片的md5作为id
with open(image_path + row['path'], 'rb') as f:
image_data = f.read()
md5_value = hashlib.md5(image_data).hexdigest()
ids.append(md5_value)
images_description.append(row['description'])
images_metadata.append({
"gender": row['gender'],
"path": row['path']
})
# 将数据添加到集合
# 每达到 batch_size 就执行一次 upsert
if len(ids) >= batch_size:
collection.upsert(
ids=list(ids),
documents=images_description,
metadatas=images_metadata # 添加自定义属性
)
# 清空列表以准备下一批数据
ids.clear()
images_description.clear()
images_metadata.clear()
if ids:
collection.upsert(
ids=list(ids),
documents=images_description,
metadatas=images_metadata # 添加自定义属性
)
print("Data successfully stored in the vector database.")
def query(gender, content):
collection = client.get_collection("sub_sketches_description", embedding_function=embedding_fn)
# 6. 查询相似内容
user_gender = gender # 用户输入的性别
user_content = content # 用户输入的内容
results = collection.query(
query_texts=user_content,
n_results=5, # 返回前 5 个结果
where={"gender": user_gender} # 根据性别过滤
)
# 输出结果
resp = []
for document, result in zip(results['documents'][0], results['metadatas'][0]):
# print("Path:", result['path'])
# print("Content:", document)
resp.append(result['path'])
return resp
if __name__ == '__main__':
# create_collection()
query("female", "I need a long sleeve dress")