diff --git a/app/api/api_query_image.py b/app/api/api_query_image.py new file mode 100644 index 0000000..d27c67b --- /dev/null +++ b/app/api/api_query_image.py @@ -0,0 +1,36 @@ +import json +import logging +from http.client import HTTPException + +from fastapi import APIRouter + +from app.schemas.query_image import QueryImageModel +from app.schemas.response_template import ResponseModel +from app.service.search_image_with_text.service import query + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/query_image") +def query_image(request_data: QueryImageModel): + """ + 对话机器人 + 创建一个具有以下参数的请求体: + - **gender**: 性别 + - **content**: 用户输入的内容 + + 示例参数: + { + "gender": "male", + "content": "give me a long sleeve blouse", + } + """ + try: + logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict())}") + data = query(request_data.gender, request_data.content) + logger.info(f"query_image response @@@@@@:{json.dumps(data)}") + except Exception as e: + logger.warning(f"query_image Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/schemas/query_image.py b/app/schemas/query_image.py new file mode 100644 index 0000000..147603f --- /dev/null +++ b/app/schemas/query_image.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class QueryImageModel(BaseModel): + gender: str + content: str diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py new file mode 100644 index 0000000..712050f --- /dev/null +++ b/app/service/search_image_with_text/service.py @@ -0,0 +1,89 @@ +import chromadb +import hashlib + +import pandas as pd +from chromadb.config import Settings +from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction +from tqdm import tqdm + +# 读取 csv 文件 +# csv_file_path = r'D:/Files/csv/output/output.csv' +# image_path = r'D:/images-clean' + +# df = pd.read_csv(csv_file_path, encoding='Windows-1252') + +# 创建 Chroma 客户端 +client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db")) +# client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db")) +# client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) +# 创建集合 +embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddings", model_name="mxbai-embed-large") + + +# def create_collection(): +# collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn) +# +# # 存储数据,包括自定义属性 +# images_description = [] +# images_metadata = [] +# ids = [] +# batch_size = 41666 # 最大批量大小 +# for index, row in tqdm(df.iterrows()): +# # 将图片的md5作为id +# with open(image_path + row['path'], 'rb') as f: +# image_data = f.read() +# md5_value = hashlib.md5(image_data).hexdigest() +# ids.append(md5_value) +# images_description.append(row['description']) +# images_metadata.append({ +# "gender": row['gender'], +# "path": row['path'] +# }) +# +# # 将数据添加到集合 +# # 每达到 batch_size 就执行一次 upsert +# if len(ids) >= batch_size: +# collection.upsert( +# ids=list(ids), +# documents=images_description, +# metadatas=images_metadata # 添加自定义属性 +# ) +# # 清空列表以准备下一批数据 +# ids.clear() +# images_description.clear() +# images_metadata.clear() +# +# if ids: +# collection.upsert( +# ids=list(ids), +# documents=images_description, +# metadatas=images_metadata # 添加自定义属性 +# ) +# +# print("Data successfully stored in the vector database.") + + +def query(gender, content): + collection = client.get_collection("sub_sketches_description", embedding_function=embedding_fn) + # 6. 查询相似内容 + user_gender = gender # 用户输入的性别 + user_content = content # 用户输入的内容 + + results = collection.query( + query_texts=user_content, + n_results=5, # 返回前 5 个结果 + where={"gender": user_gender} # 根据性别过滤 + ) + + # 输出结果 + resp = [] + for document, result in zip(results['documents'][0], results['metadatas'][0]): + # print("Path:", result['path']) + # print("Content:", document) + resp.append(result['path']) + return resp + + +if __name__ == '__main__': + # create_collection() + query("female", "I need a long sleeve dress")