feat dockerfile 修改

fix
This commit is contained in:
zhouchengrong
2024-10-29 17:17:30 +08:00
parent 31d7f55402
commit 1ba67d0bf7

View File

@@ -7,10 +7,10 @@ from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaE
from tqdm import tqdm from tqdm import tqdm
# 读取 csv 文件 # 读取 csv 文件
csv_file_path = r'D:/Files/csv/output/output.csv' # csv_file_path = r'D:/Files/csv/output/output.csv'
image_path = r'D:/images-clean' # image_path = r'D:/images-clean'
df = pd.read_csv(csv_file_path, encoding='Windows-1252') # df = pd.read_csv(csv_file_path, encoding='Windows-1252')
# 创建 Chroma 客户端 # 创建 Chroma 客户端
client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db")) client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db"))
@@ -20,47 +20,47 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector
embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large")
def create_collection(): # def create_collection():
collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn) # collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn)
#
# 存储数据,包括自定义属性 # # 存储数据,包括自定义属性
images_description = [] # images_description = []
images_metadata = [] # images_metadata = []
ids = [] # ids = []
batch_size = 41666 # 最大批量大小 # batch_size = 41666 # 最大批量大小
for index, row in tqdm(df.iterrows()): # for index, row in tqdm(df.iterrows()):
# 将图片的md5作为id # # 将图片的md5作为id
with open(image_path + row['path'], 'rb') as f: # with open(image_path + row['path'], 'rb') as f:
image_data = f.read() # image_data = f.read()
md5_value = hashlib.md5(image_data).hexdigest() # md5_value = hashlib.md5(image_data).hexdigest()
ids.append(md5_value) # ids.append(md5_value)
images_description.append(row['description']) # images_description.append(row['description'])
images_metadata.append({ # images_metadata.append({
"gender": row['gender'], # "gender": row['gender'],
"path": row['path'] # "path": row['path']
}) # })
#
# 将数据添加到集合 # # 将数据添加到集合
# 每达到 batch_size 就执行一次 upsert # # 每达到 batch_size 就执行一次 upsert
if len(ids) >= batch_size: # if len(ids) >= batch_size:
collection.upsert( # collection.upsert(
ids=list(ids), # ids=list(ids),
documents=images_description, # documents=images_description,
metadatas=images_metadata # 添加自定义属性 # metadatas=images_metadata # 添加自定义属性
) # )
# 清空列表以准备下一批数据 # # 清空列表以准备下一批数据
ids.clear() # ids.clear()
images_description.clear() # images_description.clear()
images_metadata.clear() # images_metadata.clear()
#
if ids: # if ids:
collection.upsert( # collection.upsert(
ids=list(ids), # ids=list(ids),
documents=images_description, # documents=images_description,
metadatas=images_metadata # 添加自定义属性 # metadatas=images_metadata # 添加自定义属性
) # )
#
print("Data successfully stored in the vector database.") # print("Data successfully stored in the vector database.")
def query(gender, content): def query(gender, content):