新增 取消agent配饰(保留鞋子)推荐,改为默认随机配饰搭配 使用json文件补充stylist删除掉的必要配饰
This commit is contained in:
@@ -296,21 +296,36 @@ class AsyncStylistAgent:
|
|||||||
async def _get_random_accessories(self, stylist, item_count):
|
async def _get_random_accessories(self, stylist, item_count):
|
||||||
stylist_item = []
|
stylist_item = []
|
||||||
stylist_item_ids = []
|
stylist_item_ids = []
|
||||||
|
|
||||||
|
filter_items = [
|
||||||
|
{"item_group_id": {"$ne": "Clothing"}},
|
||||||
|
{"item_group_id": {"$ne": "Shoes"}},
|
||||||
|
{"modality": "image"}
|
||||||
|
]
|
||||||
|
random_items = []
|
||||||
|
|
||||||
for i in stylist:
|
for i in stylist:
|
||||||
# 1. 根据stylist要求抽取item
|
# 1. 根据stylist要求抽取item
|
||||||
query_embedding = self.local_db.get_clip_embedding(i['text'], is_image=False)
|
query_embedding = self.local_db.get_clip_embedding(i['text'], is_image=False)
|
||||||
stylist_results = self.local_db.query_local_db(query_embedding, i['category'], n_results=10)
|
stylist_results = self.local_db.query_local_db(query_embedding, i['category'], n_results=10)
|
||||||
stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count'])
|
stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count'])
|
||||||
stylist_item_ids += [item_id['item_id'] for item_id in stylist_item]
|
stylist_item_ids += [item_id['item_id'] for item_id in stylist_item]
|
||||||
|
filter_items.append({"item_group_id": {"$ne": i['category']}})
|
||||||
|
|
||||||
accessories_count = 9 - item_count - len(stylist_item)
|
accessories_count = 9 - item_count - len(stylist_item)
|
||||||
|
|
||||||
if accessories_count > 0:
|
if accessories_count > 0:
|
||||||
accessories_count = random.randint(1, accessories_count)
|
if accessories_count > 4:
|
||||||
# 2. 在配饰池中过滤掉已经选中的item ,然后抽两个item
|
accessories_count = 4
|
||||||
random_single_ids = random.choices(list(set(self.local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
|
for i in range(accessories_count):
|
||||||
random_items = self.local_db.random_get_accessories(random_single_ids)['metadatas']
|
random_poll = self.local_db.load_filtered_ids(filter_items)
|
||||||
|
# 2. 在配饰池中过滤掉已经选中的item ,然后抽两个item
|
||||||
|
item = self.local_db.random_get_accessories(random.choice(random_poll))
|
||||||
|
filter_items.append({"item_group_id": {"$ne": item['metadatas'][0]['category']}})
|
||||||
|
random_items.append(item['metadatas'][0])
|
||||||
|
|
||||||
all_items = stylist_item + random_items
|
all_items = stylist_item + random_items
|
||||||
|
|
||||||
else:
|
else:
|
||||||
all_items = stylist_item
|
all_items = stylist_item
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import random
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import chromadb
|
import chromadb
|
||||||
@@ -17,8 +18,12 @@ class VectorDatabase():
|
|||||||
|
|
||||||
self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
|
self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
|
||||||
self.processor = CLIPProcessor.from_pretrained(embedding_model_name)
|
self.processor = CLIPProcessor.from_pretrained(embedding_model_name)
|
||||||
self.cache_filtered_ids = self.load_filtered_ids()
|
# self.cache_filtered_ids = self.load_filtered_ids([
|
||||||
self.total_count = len(self.cache_filtered_ids)
|
# {"item_group_id": {"$ne": "Clothing"}},
|
||||||
|
# {"item_group_id": {"$ne": "Shoes"}},
|
||||||
|
# {"modality": "image"}
|
||||||
|
# ])
|
||||||
|
# self.total_count = len(self.cache_filtered_ids)
|
||||||
|
|
||||||
def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
|
def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
|
||||||
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
|
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
|
||||||
@@ -63,14 +68,11 @@ class VectorDatabase():
|
|||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def load_filtered_ids(self):
|
def load_filtered_ids(self, filter_item):
|
||||||
print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
|
# print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
|
||||||
|
start_time = time.time()
|
||||||
FILTER_CRITERIA = {
|
FILTER_CRITERIA = {
|
||||||
"$and": [
|
"$and": filter_item
|
||||||
{"item_group_id": {"$ne": "Clothing"}},
|
|
||||||
{"item_group_id": {"$ne": "Shoes"}}, # 新增:过滤 Shoes
|
|
||||||
{"modality": "image"}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
MAX_LIMIT = 100000
|
MAX_LIMIT = 100000
|
||||||
|
|
||||||
@@ -82,7 +84,8 @@ class VectorDatabase():
|
|||||||
include=[]
|
include=[]
|
||||||
)
|
)
|
||||||
all_matched_ids = all_ids_results['ids']
|
all_matched_ids = all_ids_results['ids']
|
||||||
print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
|
# print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
|
||||||
|
print(time.time() - start_time)
|
||||||
return all_matched_ids
|
return all_matched_ids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -114,20 +117,29 @@ if __name__ == '__main__':
|
|||||||
'count': 2,
|
'count': 2,
|
||||||
'category': "Jewelry"
|
'category': "Jewelry"
|
||||||
}
|
}
|
||||||
|
|
||||||
max_len = 5
|
max_len = 5
|
||||||
local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
|
local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
|
||||||
|
A = local_db.load_filtered_ids([
|
||||||
|
{"item_group_id": {"$ne": "Clothing"}},
|
||||||
|
{"item_group_id": {"$ne": "Shoes"}},
|
||||||
|
{"modality": "image"}
|
||||||
|
])
|
||||||
# print(db.random_get_accessories())
|
# print(db.random_get_accessories())
|
||||||
|
start_time = time.time()
|
||||||
query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False)
|
X = local_db.random_get_accessories(['ELI699_img'])
|
||||||
|
print(X)
|
||||||
results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10)
|
print(time.time() - start_time)
|
||||||
# 2. 从结果集中抽 stylist['count'] 个item
|
# query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False)
|
||||||
stylist_item = random.choices(results['metadatas'][0], k=stylist['count'])
|
#
|
||||||
stylist_item_ids = [item_id['item_id'] for item_id in stylist_item]
|
# results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10)
|
||||||
|
# # 2. 从结果集中抽 stylist['count'] 个item
|
||||||
# 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item
|
# stylist_item = random.choices(results['metadatas'][0], k=stylist['count'])
|
||||||
accessories_count = 9 - max_len - stylist['count']
|
# stylist_item_ids = [item_id['item_id'] for item_id in stylist_item]
|
||||||
|
#
|
||||||
random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
|
# # 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item
|
||||||
random_items = local_db.random_get_accessories(random_single_ids)['metadatas']
|
# accessories_count = 9 - max_len - stylist['count']
|
||||||
all_items = stylist_item + random_items
|
#
|
||||||
|
# random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
|
||||||
|
# random_items = local_db.random_get_accessories(random_single_ids)['metadatas']
|
||||||
|
# all_items = stylist_item + random_items
|
||||||
|
|||||||
Reference in New Issue
Block a user