diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py
index 47a9dde..98f6ac4 100644
--- a/app/service/search_image_with_text/service.py
+++ b/app/service/search_image_with_text/service.py
@@ -7,10 +7,10 @@ from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaE
 from tqdm import tqdm
 
 # Read the CSV file
-# csv_file_path = r'D:/Files/csv/output/output.csv'
-# image_path = r'D:/images-clean'
+csv_file_path = r'D:/Files/csv/output/output.csv'
+image_path = r'D:/images-clean'
 
-# df = pd.read_csv(csv_file_path, encoding='Windows-1252')
+df = pd.read_csv(csv_file_path, encoding='Windows-1252')
 
 # Create the Chroma client
 client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db"))
@@ -20,47 +20,47 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector
 
 embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large")
 
-# def create_collection():
-#     collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn)
-#
-#     # Store the data, including custom attributes
-#     images_description = []
-#     images_metadata = []
-#     ids = []
-#     batch_size = 41666  # maximum batch size
-#     for index, row in tqdm(df.iterrows()):
-#         # Use the image's MD5 hash as the id
-#         with open(image_path + row['path'], 'rb') as f:
-#             image_data = f.read()
-#         md5_value = hashlib.md5(image_data).hexdigest()
-#         ids.append(md5_value)
-#         images_description.append(row['description'])
-#         images_metadata.append({
-#             "gender": row['gender'],
-#             "path": row['path']
-#         })
-#
-#         # Add the data to the collection
-#         # Run an upsert each time batch_size is reached
-#         if len(ids) >= batch_size:
-#             collection.upsert(
-#                 ids=list(ids),
-#                 documents=images_description,
-#                 metadatas=images_metadata  # attach the custom attributes
-#             )
-#             # Clear the lists to prepare for the next batch
-#             ids.clear()
-#             images_description.clear()
-#             images_metadata.clear()
-#
-#     if ids:
-#         collection.upsert(
-#             ids=list(ids),
-#             documents=images_description,
-#             metadatas=images_metadata  # attach the custom attributes
-#         )
-#
-#     print("Data successfully stored in the vector database.")
+def create_collection():
+    collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn)
+
+    # Store the data, including custom attributes
+    images_description = []
+    images_metadata = []
+    ids = []
+    batch_size = 41666  # maximum batch size
+    for index, row in tqdm(df.iterrows()):
+        # Use the image's MD5 hash as the id
+        with open(image_path + row['path'], 'rb') as f:
+            image_data = f.read()
+        md5_value = hashlib.md5(image_data).hexdigest()
+        ids.append(md5_value)
+        images_description.append(row['description'])
+        images_metadata.append({
+            "gender": row['gender'],
+            "path": row['path']
+        })
+
+        # Add the data to the collection
+        # Run an upsert each time batch_size is reached
+        if len(ids) >= batch_size:
+            collection.upsert(
+                ids=list(ids),
+                documents=images_description,
+                metadatas=images_metadata  # attach the custom attributes
+            )
+            # Clear the lists to prepare for the next batch
+            ids.clear()
+            images_description.clear()
+            images_metadata.clear()
+
+    if ids:
+        collection.upsert(
+            ids=list(ids),
+            documents=images_description,
+            metadatas=images_metadata  # attach the custom attributes
+        )
+
+    print("Data successfully stored in the vector database.")
 
 
 def query(gender, content):