新增 取消agent配饰(保留鞋子)推荐,改为默认随机配饰搭配 使用json文件补充stylist删除掉的必要配饰
This commit is contained in:
@@ -296,21 +296,36 @@ class AsyncStylistAgent:
|
||||
async def _get_random_accessories(self, stylist, item_count):
|
||||
stylist_item = []
|
||||
stylist_item_ids = []
|
||||
|
||||
filter_items = [
|
||||
{"item_group_id": {"$ne": "Clothing"}},
|
||||
{"item_group_id": {"$ne": "Shoes"}},
|
||||
{"modality": "image"}
|
||||
]
|
||||
random_items = []
|
||||
|
||||
for i in stylist:
|
||||
# 1. 根据stylist要求抽取item
|
||||
query_embedding = self.local_db.get_clip_embedding(i['text'], is_image=False)
|
||||
stylist_results = self.local_db.query_local_db(query_embedding, i['category'], n_results=10)
|
||||
stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count'])
|
||||
stylist_item_ids += [item_id['item_id'] for item_id in stylist_item]
|
||||
filter_items.append({"item_group_id": {"$ne": i['category']}})
|
||||
|
||||
accessories_count = 9 - item_count - len(stylist_item)
|
||||
|
||||
if accessories_count > 0:
|
||||
accessories_count = random.randint(1, accessories_count)
|
||||
if accessories_count > 4:
|
||||
accessories_count = 4
|
||||
for i in range(accessories_count):
|
||||
random_poll = self.local_db.load_filtered_ids(filter_items)
|
||||
# 2. 在配饰池中过滤掉已经选中的item ,然后抽两个item
|
||||
random_single_ids = random.choices(list(set(self.local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
|
||||
random_items = self.local_db.random_get_accessories(random_single_ids)['metadatas']
|
||||
item = self.local_db.random_get_accessories(random.choice(random_poll))
|
||||
filter_items.append({"item_group_id": {"$ne": item['metadatas'][0]['category']}})
|
||||
random_items.append(item['metadatas'][0])
|
||||
|
||||
all_items = stylist_item + random_items
|
||||
|
||||
else:
|
||||
all_items = stylist_item
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import random
|
||||
import time
|
||||
|
||||
import torch
|
||||
import chromadb
|
||||
@@ -17,8 +18,12 @@ class VectorDatabase():
|
||||
|
||||
self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
|
||||
self.processor = CLIPProcessor.from_pretrained(embedding_model_name)
|
||||
self.cache_filtered_ids = self.load_filtered_ids()
|
||||
self.total_count = len(self.cache_filtered_ids)
|
||||
# self.cache_filtered_ids = self.load_filtered_ids([
|
||||
# {"item_group_id": {"$ne": "Clothing"}},
|
||||
# {"item_group_id": {"$ne": "Shoes"}},
|
||||
# {"modality": "image"}
|
||||
# ])
|
||||
# self.total_count = len(self.cache_filtered_ids)
|
||||
|
||||
def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
|
||||
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
|
||||
@@ -63,14 +68,11 @@ class VectorDatabase():
|
||||
)
|
||||
return results
|
||||
|
||||
def load_filtered_ids(self):
|
||||
print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
|
||||
def load_filtered_ids(self, filter_item):
|
||||
# print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
|
||||
start_time = time.time()
|
||||
FILTER_CRITERIA = {
|
||||
"$and": [
|
||||
{"item_group_id": {"$ne": "Clothing"}},
|
||||
{"item_group_id": {"$ne": "Shoes"}}, # 新增:过滤 Shoes
|
||||
{"modality": "image"}
|
||||
]
|
||||
"$and": filter_item
|
||||
}
|
||||
MAX_LIMIT = 100000
|
||||
|
||||
@@ -82,7 +84,8 @@ class VectorDatabase():
|
||||
include=[]
|
||||
)
|
||||
all_matched_ids = all_ids_results['ids']
|
||||
print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
|
||||
# print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
|
||||
print(time.time() - start_time)
|
||||
return all_matched_ids
|
||||
|
||||
except Exception as e:
|
||||
@@ -114,20 +117,29 @@ if __name__ == '__main__':
|
||||
'count': 2,
|
||||
'category': "Jewelry"
|
||||
}
|
||||
|
||||
max_len = 5
|
||||
local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
|
||||
A = local_db.load_filtered_ids([
|
||||
{"item_group_id": {"$ne": "Clothing"}},
|
||||
{"item_group_id": {"$ne": "Shoes"}},
|
||||
{"modality": "image"}
|
||||
])
|
||||
# print(db.random_get_accessories())
|
||||
|
||||
query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False)
|
||||
|
||||
results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10)
|
||||
# 2. 从结果集中抽 stylist['count'] 个item
|
||||
stylist_item = random.choices(results['metadatas'][0], k=stylist['count'])
|
||||
stylist_item_ids = [item_id['item_id'] for item_id in stylist_item]
|
||||
|
||||
# 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item
|
||||
accessories_count = 9 - max_len - stylist['count']
|
||||
|
||||
random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
|
||||
random_items = local_db.random_get_accessories(random_single_ids)['metadatas']
|
||||
all_items = stylist_item + random_items
|
||||
start_time = time.time()
|
||||
X = local_db.random_get_accessories(['ELI699_img'])
|
||||
print(X)
|
||||
print(time.time() - start_time)
|
||||
# query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False)
|
||||
#
|
||||
# results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10)
|
||||
# # 2. 从结果集中抽 stylist['count'] 个item
|
||||
# stylist_item = random.choices(results['metadatas'][0], k=stylist['count'])
|
||||
# stylist_item_ids = [item_id['item_id'] for item_id in stylist_item]
|
||||
#
|
||||
# # 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item
|
||||
# accessories_count = 9 - max_len - stylist['count']
|
||||
#
|
||||
# random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
|
||||
# random_items = local_db.random_get_accessories(random_single_ids)['metadatas']
|
||||
# all_items = stylist_item + random_items
|
||||
|
||||
Reference in New Issue
Block a user