新增 取消agent配饰(保留鞋子)推荐,改为默认随机配饰搭配 使用json文件补充stylist删除掉的必要配饰

This commit is contained in:
zhh
2025-11-21 10:46:14 +08:00
parent e3d4939718
commit 2d38e3fc0b
4 changed files with 129 additions and 50 deletions

View File

@@ -55,7 +55,7 @@ class AsyncStylistAgent:
self.gcs_bucket = "lc_stylist_agent_outfit_items"
self.minio_bucket = "lanecarford"
def _load_style_guide(self, path: str) -> str:
def _load_style_guide(self, path: str):
"""加载 markdown 风格指南内容。"""
parts = path.split('/', 1)
if len(parts) != 2:
@@ -63,18 +63,24 @@ class AsyncStylistAgent:
bucket_name, object_name = parts
try:
# 1. 获取对象
# 获取对象 读取内容
response = minio_client.get_object(bucket_name, object_name)
# 2. 读取内容
content_bytes = response.read()
# 3. 关闭连接
response.close()
response.release_conn()
json_response = minio_client.get_object(bucket_name, object_name.replace('.md', '.json'))
json_data = json_response.data
# 4. 解码并返回
return content_bytes.decode('utf-8')
# 关闭连接
response.close()
json_response.close()
response.release_conn()
json_response.release_conn()
# 4. 解析 JSON 字符串
json_string = json_data.decode('utf-8')
json_content = json.loads(json_string)
return content_bytes.decode('utf-8'), json_content
except Exception as e:
raise Exception(f"Failed to load style guide from {path}: {e}")
@@ -214,6 +220,33 @@ class AsyncStylistAgent:
# 返回一个停止信号以防止循环继续
return json.dumps({"action": "stop", "reason": f"API_ERROR: {str(e)}"})
async def _merge_images(self, user_id: str):
"""
实际调用 Gemini API 的函数,接受文本和可选的图片路径列表。
Args:
user_input: 发送给模型的主文本内容。
image_paths: 待发送图片的本地路径列表。
Returns:
模型的响应文本(预期为 JSON 字符串)。
"""
minio_path = ""
if self.outfit_items:
merged_image = merge_images_to_square(self.outfit_items, max_len=9, add_text=False)
image_bytes_io = io.BytesIO()
image_format = 'JPEG'
merged_image.save(image_bytes_io, format=image_format)
image_bytes = image_bytes_io.getvalue()
file_name = uuid.uuid4()
blob_name = f"lc_stylist_agent_outfit_items/{user_id}/{file_name}.jpg"
responses = oss_upload_image(oss_client=minio_client, bucket=self.minio_bucket, object_name=blob_name, image_bytes=image_bytes)
minio_path = f"{responses.bucket_name}/{responses.object_name}"
return minio_path
def _parse_gemini_response(self, response_text: str) -> Optional[Dict[str, Any]]:
"""安全解析 Gemini 的 JSON 响应。"""
try:
@@ -260,20 +293,37 @@ class AsyncStylistAgent:
print(f"An error occurred during item retrieval: {e}")
return None
async def _get_random_accessories(self):
results = self.local_db.random_get_accessories()
async def _get_random_accessories(self, stylist):
stylist_item = []
stylist_item_ids = []
for i in stylist:
# 1. 根据stylist要求抽取item
query_embedding = self.local_db.get_clip_embedding(i['text'], is_image=False)
stylist_results = self.local_db.query_local_db(query_embedding, i['category'], n_results=10)
stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count'])
stylist_item_ids += [item_id['item_id'] for item_id in stylist_item]
# 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核)
best_meta = results['metadatas'][0]
return {
"item_id": best_meta['item_id'], # 从 metadata 字典中安全获取
"category": best_meta['category'],
"gpt_description": best_meta['description'],
'description': best_meta['description'],
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
# 这里假设 item_id 就是文件名的一部分
"image_path": os.path.join(f"{best_meta['item_id']}.jpg")
}
accessories_count = 2
# 2. 在配饰池中过滤掉已经选中的item 然后抽两个item
random_single_ids = random.choices(list(set(self.local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
random_items = self.local_db.random_get_accessories(random_single_ids)['metadatas']
all_items = stylist_item + random_items
items_data = []
for best_meta in all_items:
items_data.append({
"item_id": best_meta['item_id'], # 从 metadata 字典中安全获取
"category": best_meta['category'],
"gpt_description": best_meta['description'],
'description': best_meta['description'],
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
# 这里假设 item_id 就是文件名的一部分
"image_path": os.path.join(f"{best_meta['item_id']}.jpg")
})
return items_data
def _build_user_input(self) -> str:
"""构建发送给 Gemini 的用户输入,包含已选单品信息。"""
@@ -294,7 +344,7 @@ class AsyncStylistAgent:
"""主流程控制循环。"""
print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---")
self.style_guide = self._load_style_guide(stylist_path)
self.style_guide, self.style_accessories_guide = self._load_style_guide(stylist_path)
self.system_prompt = self._build_system_prompt(request_summary, gender)
response_data = {"status": "",
"message": "",
@@ -342,12 +392,13 @@ class AsyncStylistAgent:
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
# 新增配饰
new_item = await self._get_random_accessories()
self.outfit_items.append(new_item)
user_input = self._build_user_input()
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')})
# 根据stylist要求随机增加配饰 3-4个配饰
new_item = await self._get_random_accessories(self.style_accessories_guide)
for item in new_item:
self.outfit_items.append(item)
response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')})
response_data['path'] = await self._merge_images(user_id)
logger.info(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided')
@@ -424,12 +475,13 @@ class AsyncStylistAgent:
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
# 新增配饰
new_item = await self._get_random_accessories()
self.outfit_items.append(new_item)
user_input = self._build_user_input()
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')})
# 根据stylist要求随机增加配饰 3-4个配饰
new_item = await self._get_random_accessories(self.style_accessories_guide)
for item in new_item:
self.outfit_items.append(item)
response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')})
response_data['path'] = await self._merge_images(user_id)
logger.info("🚨 达到最大搭配数量限制,强制终止。")
self.stop_reason = "Finish reason: Reached max outfit length."
response_data['status'] = "stop"

View File

@@ -69,7 +69,7 @@ class VectorDatabase():
"$and": [
{"item_group_id": {"$ne": "Clothing"}},
{"item_group_id": {"$ne": "Shoes"}}, # 新增:过滤 Shoes
{"modality": "image"},
{"modality": "image"}
]
}
MAX_LIMIT = 100000
@@ -89,12 +89,11 @@ class VectorDatabase():
print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
return []
def random_get_accessories(self):
random_single_id = random.choice(self.cache_filtered_ids)
def random_get_accessories(self, ids):
# 2. 调用 ChromaDB只查询这一个 ID 的详细信息
try:
final_results = self.collection.get(
ids=[random_single_id],
ids=ids,
include=["metadatas"] # 你只需要元数据
)
@@ -110,5 +109,25 @@ class VectorDatabase():
if __name__ == '__main__':
db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
print(db.random_get_accessories())
stylist = {
'text': "gold necklace",
'count': 2,
'category': "Jewelry"
}
max_len = 5
local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
# print(db.random_get_accessories())
query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False)
results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10)
# 2. 从结果集中抽 stylist['count'] 个item
stylist_item = random.choices(results['metadatas'][0], k=stylist['count'])
stylist_item_ids = [item_id['item_id'] for item_id in stylist_item]
# 3. 从随机库中抽取配饰总数达到9件 需过滤掉已经抽中的item
accessories_count = 9 - max_len - stylist['count']
random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
random_items = local_db.random_get_accessories(random_single_ids)['metadatas']
all_items = stylist_item + random_items

View File

@@ -114,12 +114,19 @@ if __name__ == '__main__':
# url = "lanecarford/lc_stylist_agent_outfit_items/string/7fed1c7b-9efd-41fa-a335-182c310ea611.jpg"
# url = "lanecarford/lc_stylist_agent_outfit_items/string/5de155d0-56a6-43e8-a2f1-7538fce86220.jpg"
# url = "lanecarford/lc_stylist_agent_outfit_items/string/1cd1803c-5f51-4961-a4f2-2acd3e0d8294.jpg"
url = 'lanecarford/lc_stylist_agent_outfit_items/string/99cd8cc0-856a-487d-bb21-5684855ef48f.jpg'
url = [
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/d9df7c48-c7e5-47f9-be67-07f0d175d202.jpg',
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/ddf39b9c-69f0-4b28-95ed-9d823fa82e35.jpg',
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/112194a0-dc1d-4151-8c58-82642142a553.jpg',
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/788007f1-e44b-4390-ad9e-a2d4ba406379.jpg'
]
read_type = "1"
img = oss_get_image(oss_client=minio_client, path=url, data_type=read_type)
if read_type == "cv2":
cv2.imshow("", img)
cv2.waitKey(0)
else:
img.show()
img.save("4.png")
for id, i in enumerate(url):
img = oss_get_image(minio_client, i, read_type)
img = oss_get_image(oss_client=minio_client, path=i, data_type=read_type)
if read_type == "cv2":
cv2.imshow("", img)
cv2.waitKey(0)
else:
img.show()
img.save(f"{time.time()}.png")

View File

@@ -16,4 +16,5 @@ pytorch-fid==0.3.0
open-clip-torch==2.24.0
pytorch-fid==0.3.0
litserve
# pip install git+https://github.com/openai/CLIP.git
# pip install git+https://github.com/openai/CLIP.git
# pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128