import hashlib
from typing import List

import chromadb
import pandas as pd
from chromadb.config import Settings
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
from tqdm import tqdm

# Read the CSV file
# csv_file_path = r'D:/Files/csv/output/output.csv'
# image_path = r'D:/images-clean'
# df = pd.read_csv(csv_file_path, encoding='Windows-1252')

# Create the Chroma client (persistent, stored under /vector_db)
client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db"))
# client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db"))
# client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db"))

# Embedding function used for the collection
# embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large")
embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.243:11434/api/embeddings", model_name="mxbai-embed-large")


# def create_collection():
#     collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn)
#
#     # Store the data, including custom attributes
#     images_description = []
#     images_metadata = []
#     ids = []
#     batch_size = 41666  # maximum batch size
#     for index, row in tqdm(df.iterrows()):
#         # Use the MD5 of the image as its id
#         with open(image_path + row['path'], 'rb') as f:
#             image_data = f.read()
#             md5_value = hashlib.md5(image_data).hexdigest()
#         ids.append(md5_value)
#         images_description.append(row['description'])
#         images_metadata.append({
#             "gender": row['gender'],
#             "path": row['path']
#         })
#
#         # Add the data to the collection:
#         # upsert once every time batch_size items have accumulated
#         if len(ids) >= batch_size:
#             collection.upsert(
#                 ids=list(ids),
#                 documents=images_description,
#                 metadatas=images_metadata  # attach custom attributes
#             )
#             # Clear the lists to prepare for the next batch
#             ids.clear()
#             images_description.clear()
#             images_metadata.clear()
#
#     if ids:
#         collection.upsert(
#             ids=list(ids),
#             documents=images_description,
#             metadatas=images_metadata  # attach custom attributes
#         )
#
#     print("Data successfully stored in the vector database.")


def query(gender: str, content: str, n_results: int = 5) -> List[str]:
    """Return the image paths whose descriptions best match ``content``.

    Looks up the ``sub_sketches_description`` collection, runs a similarity
    query for ``content`` restricted to entries whose ``gender`` metadata
    equals the lower-cased ``gender``, and collects the ``path`` metadata of
    each hit.

    Args:
        gender: Gender to filter on; lower-cased before being used in the
            metadata filter.
        content: Free-text description to search for.
        n_results: Maximum number of matches requested from the collection
            (defaults to 5, the previously hard-coded value).

    Returns:
        The ``path`` metadata value of every returned match, in the order
        the collection returned them.
    """
    collection = client.get_collection("sub_sketches_description", embedding_function=embedding_fn)

    # Query for similar content, filtered by gender.
    results = collection.query(
        query_texts=[content],
        n_results=n_results,
        where={"gender": gender.lower()},
    )

    # One query text was submitted, so the matches live in the first (only)
    # result row; only the metadata is needed to recover the image path.
    return [metadata['path'] for metadata in results['metadatas'][0]]


if __name__ == '__main__':
    # create_collection()
    # Print the matches so that running the script shows the query outcome.
    for matched_path in query("female", "I need a long sleeve dress"):
        print(matched_path)