Files
AiDA_Python/app/service/search_image_with_text/service.py
2024-11-08 14:35:23 +08:00

93 lines
3.2 KiB
Python

import chromadb
import hashlib
import pandas as pd
from chromadb.config import Settings
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
from tqdm import tqdm
from app.core.config import OLLAMA_URL
# Module-level setup: the Chroma client and the embedding function shared by
# the functions in this module.
#
# Inputs for the (currently disabled) one-off ingestion routine: read the CSV file.
# csv_file_path = r'D:/Files/csv/output/output.csv'
# image_path = r'D:/images-clean'
# df = pd.read_csv(csv_file_path, encoding='Windows-1252')
# Create the Chroma client. It is persistent and stores its data under the
# absolute path "/vector_db".
client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db"))
# Alternative persist locations kept for reference (presumably used during
# local development -- confirm before re-enabling):
# client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db"))
# client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db"))
# Embedding function handed to the collection. The endpoint comes from
# app.core.config.OLLAMA_URL; the commented line below hard-codes a localhost
# endpoint instead.
# embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large")
embedding_fn = OllamaEmbeddingFunction(url=OLLAMA_URL, model_name="mxbai-embed-large")
# Disabled one-off ingestion routine, kept for reference. It relies on the
# commented-out `df` and `image_path` definitions near the top of this module.
#
# def create_collection():
#     collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn)
#
#     # Store the data, including custom attributes (metadata)
#     images_description = []
#     images_metadata = []
#     ids = []
#     batch_size = 41666  # maximum batch size
#     for index, row in tqdm(df.iterrows()):
#         # Use the MD5 of the image bytes as the id
#         with open(image_path + row['path'], 'rb') as f:
#             image_data = f.read()
#         md5_value = hashlib.md5(image_data).hexdigest()
#         ids.append(md5_value)
#         images_description.append(row['description'])
#         images_metadata.append({
#             "gender": row['gender'],
#             "path": row['path']
#         })
#
#         # Add the data to the collection:
#         # run one upsert each time batch_size entries have accumulated
#         if len(ids) >= batch_size:
#             collection.upsert(
#                 ids=list(ids),
#                 documents=images_description,
#                 metadatas=images_metadata  # attach the custom attributes
#             )
#             # Clear the lists to prepare for the next batch
#             ids.clear()
#             images_description.clear()
#             images_metadata.clear()
#
#     # Flush whatever is left after the last full batch
#     if ids:
#         collection.upsert(
#             ids=list(ids),
#             documents=images_description,
#             metadatas=images_metadata  # attach the custom attributes
#         )
#
#     print("Data successfully stored in the vector database.")
def query(gender, content, n_results=5):
    """Return the image paths whose stored descriptions best match ``content``.

    Fetches the "sub_sketches_description" collection from the module-level
    client and runs a similarity query limited to entries whose ``gender``
    metadata equals the lower-cased ``gender`` argument.

    Args:
        gender: Gender label to filter on; it is lower-cased before being
            used in the metadata filter.
        content: Free-text description used as the query text.
        n_results: Maximum number of matches to request. Defaults to 5,
            the value that was previously hard-coded.

    Returns:
        A list containing the ``path`` metadata value of every match, in the
        order the collection returned them.

    NOTE(review): ``get_collection`` presumably raises when the collection
    has not been created yet -- confirm the exception type against the
    installed chromadb version before catching it in callers.
    """
    collection = client.get_collection(
        "sub_sketches_description", embedding_function=embedding_fn
    )
    results = collection.query(
        query_texts=content,
        n_results=n_results,
        where={"gender": gender.lower()},  # restrict matches by gender
    )
    # A single query text was submitted, so its matches are in the first
    # inner list. Only the metadata is needed to build the response.
    return [metadata["path"] for metadata in results["metadatas"][0]]
if __name__ == '__main__':
    # Manual smoke test: run one sample query and show the matching paths.
    # The result used to be discarded, so running the script printed nothing.
    # create_collection()
    print(query("female", "I need a long sleeve dress"))