从向量数据库中检索图片并集成到chat-robot
This commit is contained in:
36
app/api/api_query_image.py
Normal file
36
app/api/api_query_image.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from http.client import HTTPException
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from app.schemas.query_image import QueryImageModel
|
||||||
|
from app.schemas.response_template import ResponseModel
|
||||||
|
from app.service.search_image_with_text.service import query
|
||||||
|
|
||||||
|
# Router carrying the image-query endpoint defined in this module.
router = APIRouter()
# Module-named logger (instead of the root logger) so records emitted here
# can be attributed to, and configured for, this module specifically.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/query_image")
def query_image(request_data: QueryImageModel):
    """
    Retrieve clothing images that match a text description.

    The request body carries the following fields:
    - **gender**: gender used to filter the search
    - **content**: free-text description entered by the user

    Example body:
    {
        "gender": "male",
        "content": "give me a long sleeve blouse"
    }

    Returns a ``ResponseModel`` whose ``data`` is the result of the image
    search. Any failure while searching or logging is reported to the client
    as an HTTP 404 carrying the exception text.
    """
    # Imported locally on purpose: the module-level ``HTTPException`` comes
    # from ``http.client``, which accepts no ``status_code``/``detail``
    # keywords, so raising it here would fail with a TypeError instead of
    # producing the intended HTTP error response.
    from fastapi import HTTPException

    try:
        logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict())}")
        data = query(request_data.gender, request_data.content)
        logger.info(f"query_image response @@@@@@:{json.dumps(data)}")
    except Exception as e:
        logger.warning(f"query_image Run Exception @@@@@@:{e}")
        # Chain the original exception so the root cause stays in tracebacks.
        raise HTTPException(status_code=404, detail=str(e)) from e
    return ResponseModel(data=data)
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from app.api import api_attribute_retrieve
|
from app.api import api_attribute_retrieve, api_query_image
|
||||||
from app.api import api_brighten
|
from app.api import api_brighten
|
||||||
from app.api import api_chat_robot
|
from app.api import api_chat_robot
|
||||||
from app.api import api_design
|
from app.api import api_design
|
||||||
@@ -23,3 +23,4 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'],
|
|||||||
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
||||||
router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
|
router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
|
||||||
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
||||||
|
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
||||||
6
app/schemas/query_image.py
Normal file
6
app/schemas/query_image.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class QueryImageModel(BaseModel):
    """Request body accepted by the image query endpoint."""

    # Gender value forwarded to the image search.
    gender: str
    # Free-text description of the clothing the user is looking for.
    content: str
|
||||||
@@ -100,7 +100,7 @@ def chat(post_data):
|
|||||||
# session_key=f"buffer:{user_id}:{session_id}",
|
# session_key=f"buffer:{user_id}:{session_id}",
|
||||||
# )
|
# )
|
||||||
|
|
||||||
final_outputs = CallQWen.call_with_messages(input_message)
|
final_outputs = CallQWen.call_with_messages(input_message, gender)
|
||||||
# api_response = {
|
# api_response = {
|
||||||
# 'user_id': user_id,
|
# 'user_id': user_id,
|
||||||
# 'session_id': session_id,
|
# 'session_id': session_id,
|
||||||
|
|||||||
@@ -1,16 +1,31 @@
|
|||||||
|
# FASHION_CHAT_BOT_PREFIX = """
|
||||||
|
# You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can.
|
||||||
|
# The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database.
|
||||||
|
# Remember your answer should be very precise and the final output answer should not exceed 20 words.
|
||||||
|
#
|
||||||
|
# You may encounter the following types of questions:
|
||||||
|
# 1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database.
|
||||||
|
# Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables.
|
||||||
|
# Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results.
|
||||||
|
# Never query for all the columns from a specific table, only ask for the relevant columns given the question.
|
||||||
|
# You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
|
||||||
|
# DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
|
||||||
|
# If the question does not seem related to the database, just return "I don't know" as the answer.
|
||||||
|
#
|
||||||
|
# 2) If the query related to current events, you should use internet_search to seek help from the internet.
|
||||||
|
#
|
||||||
|
# 3) If the query is just casual conversation, engage in the conversation as a fashion designer assistant.
|
||||||
|
#
|
||||||
|
# Be careful to use the tools, since you are actually a chat bot. Tools can only be used when essential.
|
||||||
|
# """
|
||||||
|
|
||||||
FASHION_CHAT_BOT_PREFIX = """
|
FASHION_CHAT_BOT_PREFIX = """
|
||||||
You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can.
|
You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can.
|
||||||
The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database.
|
The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database.
|
||||||
Remember your answer should be very precise and the final output answer should not exceed 20 words.
|
Remember your answer should be very precise and the final output answer should not exceed 20 words.
|
||||||
|
|
||||||
You may encounter the following types of questions:
|
You may encounter the following types of questions:
|
||||||
1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database.
|
1) If you need to query information related to clothing retrieval, please use the get_image_from_vector_db tool.
|
||||||
Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables.
|
|
||||||
Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results.
|
|
||||||
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
|
|
||||||
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
|
|
||||||
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
|
|
||||||
If the question does not seem related to the database, just return "I don't know" as the answer.
|
|
||||||
|
|
||||||
2) If the query related to current events, you should use internet_search to seek help from the internet.
|
2) If the query related to current events, you should use internet_search to seek help from the internet.
|
||||||
|
|
||||||
@@ -37,15 +52,19 @@ ANSWER_FORMAT_SUFFIX = """
|
|||||||
My final answer are limited to 20 words and be as much precise as possible.
|
My final answer are limited to 20 words and be as much precise as possible.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# TOOLS_FUNCTIONS_SUFFIX = (
|
||||||
|
# "If the input involves clothing queries,"
|
||||||
|
# "I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."
|
||||||
|
# "All SQL statements must use 'ORDER BY RAND()', for example:"
|
||||||
|
# "Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
|
||||||
|
# "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'"
|
||||||
|
# "If the input does not involve clothing queries, "
|
||||||
|
# "I should engage in conversation as an assistant or search from internet with internet_search tool."
|
||||||
|
# "If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'"
|
||||||
|
# "Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool "
|
||||||
|
# )
|
||||||
TOOLS_FUNCTIONS_SUFFIX = (
|
TOOLS_FUNCTIONS_SUFFIX = (
|
||||||
"If the input involves clothing queries,"
|
"If the input involves clothing queries,please use the get_image_from_vector_db tool."
|
||||||
"I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."
|
|
||||||
"All SQL statements must use 'ORDER BY RAND()', for example:"
|
|
||||||
"Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
|
|
||||||
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'"
|
|
||||||
"If the input does not involve clothing queries, "
|
|
||||||
"I should engage in conversation as an assistant or search from internet with internet_search tool."
|
|
||||||
"If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'"
|
|
||||||
"Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool "
|
"Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool "
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from app.core.config import *
|
|||||||
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
||||||
from app.service.chat_robot.script.database import CustomDatabase
|
from app.service.chat_robot.script.database import CustomDatabase
|
||||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN
|
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN
|
||||||
|
from app.service.search_image_with_text.service import query
|
||||||
|
|
||||||
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
|
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
|
||||||
|
|
||||||
@@ -32,6 +33,12 @@ query_database_description = (
|
|||||||
"order by rand() LIMIT 2'"
|
"order by rand() LIMIT 2'"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Description shown to the model for the vector-database image lookup tool.
# NOTE(review): the text says the tool input is a single user string, while
# the tool schema declares separate `gender` and `content` parameters —
# confirm which is intended.
query_vector_db_description = (
    "Use this tool to find the clothing images that users need. "
    "If the user's input includes clothing types such as blouse, skirt, dress, "
    "outerwear, pants, or trousers, please use this tool. "
    "The input for the tool is the string provided by the user."
)
|
||||||
|
|
||||||
tutorial_description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials."
|
tutorial_description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials."
|
||||||
"Input is an empty string")
|
"Input is an empty string")
|
||||||
|
|
||||||
@@ -105,15 +112,37 @@ tools = [
|
|||||||
"function": {
|
"function": {
|
||||||
"name": "tutorial_tool",
|
"name": "tutorial_tool",
|
||||||
"description": tutorial_description,
|
"description": tutorial_description,
|
||||||
|
# "parameters": {
|
||||||
|
# "type": "object",
|
||||||
|
# "properties": {
|
||||||
|
# "sql_string": {
|
||||||
|
# "type": "string",
|
||||||
|
# "description": "由模型生成的sql语句"
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# },
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_image_from_vector_db",
|
||||||
|
"description": query_vector_db_description,
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"parameters": {
|
||||||
"properties": {
|
"type": "object",
|
||||||
"sql_string": {
|
"properties": {
|
||||||
"type": "string",
|
"gender": {
|
||||||
"description": "由模型生成的sql语句"
|
"type": "string",
|
||||||
|
"description": "性别"
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "用户描述"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
},
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -150,6 +179,10 @@ def query_database(sql_string):
|
|||||||
return CustomDatabase.run(db, sql_string)
|
return CustomDatabase.run(db, sql_string)
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_from_vector_db(gender, content):
    """Tool entry point for the clothing image lookup.

    Forwards both arguments unchanged to ``query`` and returns whatever
    that call returns.
    """
    matches = query(gender, content)
    return matches
|
||||||
|
|
||||||
|
|
||||||
@retry(exceptions=NewConnectionError, tries=3, delay=1)
|
@retry(exceptions=NewConnectionError, tries=3, delay=1)
|
||||||
def get_response(messages):
|
def get_response(messages):
|
||||||
response = Generation.call(
|
response = Generation.call(
|
||||||
@@ -164,7 +197,8 @@ def get_response(messages):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def call_with_messages(message):
|
def call_with_messages(message, gender):
|
||||||
|
user_input = message
|
||||||
print('\n')
|
print('\n')
|
||||||
# messages = [
|
# messages = [
|
||||||
# {
|
# {
|
||||||
@@ -235,6 +269,12 @@ def call_with_messages(message):
|
|||||||
tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()}
|
tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()}
|
||||||
flag = False
|
flag = False
|
||||||
result_content = tool_info['content']
|
result_content = tool_info['content']
|
||||||
|
elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db':
|
||||||
|
tool_info = {"name": "get_image_from_vector_db", "role": "tool",
|
||||||
|
'content': get_image_from_vector_db(gender, user_input)}
|
||||||
|
flag = False
|
||||||
|
result_content = tool_info['content']
|
||||||
|
response_type = "image"
|
||||||
|
|
||||||
print(f"工具输出信息:{tool_info['content']}\n")
|
print(f"工具输出信息:{tool_info['content']}\n")
|
||||||
messages.append(tool_info)
|
messages.append(tool_info)
|
||||||
@@ -257,5 +297,6 @@ def call_with_messages(message):
|
|||||||
def tutorial_tool():
|
def tutorial_tool():
|
||||||
return TUTORIAL_TOOL_RETURN
|
return TUTORIAL_TOOL_RETURN
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # call_with_messages requires both a message and a gender; calling it
    # without arguments raises TypeError. Sample values for a manual run.
    call_with_messages("I need a long sleeve dress", "female")
|
||||||
|
|||||||
89
app/service/search_image_with_text/service.py
Normal file
89
app/service/search_image_with_text/service.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
import chromadb
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from chromadb.config import Settings
|
||||||
|
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Locations of the description CSV and of the image root used when building
# the index. The CSV is NOT read at import time (see _load_descriptions), so
# importing this module only for query() does not require these files.
csv_file_path = r'D:/Files/csv/output/output.csv'
image_path = r'D:/images-clean'

# Cache for the description table; filled on first use by _load_descriptions().
_descriptions = None

# Persistent Chroma client.
client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db"))
# Embedding function shared by index creation and querying.
embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large")


def _load_descriptions():
    """Read the description CSV on first use and return the cached DataFrame."""
    global _descriptions
    if _descriptions is None:
        _descriptions = pd.read_csv(csv_file_path, encoding='Windows-1252')
    return _descriptions


def __getattr__(name):
    """Serve the former module attribute ``df`` lazily (module-level __getattr__)."""
    if name == "df":
        return _load_descriptions()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def _upsert_batch(collection, ids, documents, metadatas):
    """Upsert one batch into the collection, then empty the three buffers in place."""
    collection.upsert(
        ids=list(ids),
        documents=documents,
        metadatas=metadatas,  # custom attributes stored alongside each document
    )
    ids.clear()
    documents.clear()
    metadatas.clear()


def create_collection(batch_size=41666):
    """Build (or refresh) the ``sub_sketches_description`` collection from the CSV.

    For every CSV row the image at ``image_path + row['path']`` is read and its
    MD5 hex digest is used as the record id; ``row['description']`` is stored as
    the document and ``gender``/``path`` as metadata. Records are upserted in
    batches of at most ``batch_size``.

    Args:
        batch_size: Maximum number of records sent per upsert call.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn)
    table = _load_descriptions()

    ids = []
    images_description = []
    images_metadata = []
    # total= lets tqdm show real progress; iterrows() alone has no length.
    for _, row in tqdm(table.iterrows(), total=len(table)):
        # The MD5 of the image bytes serves as the record id.
        with open(image_path + row['path'], 'rb') as f:
            md5_value = hashlib.md5(f.read()).hexdigest()
        ids.append(md5_value)
        images_description.append(row['description'])
        images_metadata.append({
            "gender": row['gender'],
            "path": row['path'],
        })

        # Flush as soon as a full batch has accumulated.
        if len(ids) >= batch_size:
            _upsert_batch(collection, ids, images_description, images_metadata)

    # Flush the final, partially filled batch.
    if ids:
        _upsert_batch(collection, ids, images_description, images_metadata)

    print("Data successfully stored in the vector database.")
|
||||||
|
|
||||||
|
|
||||||
|
def query(gender, content, n_results=5):
    """Return the stored image paths whose descriptions best match ``content``.

    Args:
        gender: Value matched exactly against the ``gender`` metadata field.
        content: Query text handed to the collection for similarity search.
        n_results: Maximum number of matches requested (defaults to 5).

    Returns:
        A list with the ``path`` metadata value of each match, in the order
        the collection returned them.
    """
    collection = client.get_collection("sub_sketches_description", embedding_function=embedding_fn)

    results = collection.query(
        query_texts=content,
        n_results=n_results,
        where={"gender": gender},  # restrict matches to the requested gender
    )

    # Index 0 selects the match list belonging to the single query text.
    return [metadata['path'] for metadata in results['metadatas'][0]]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Manual smoke check. Run create_collection() once beforehand to build
    # the index; the result is printed so the run produces visible output.
    print(query("female", "I need a long sleeve dress"))
|
||||||
Reference in New Issue
Block a user