新增 取消agent配饰(保留鞋子)推荐,改为默认随机配饰搭配 使用json文件补充stylist删除掉的必要配饰
This commit is contained in:
@@ -55,7 +55,7 @@ class AsyncStylistAgent:
|
||||
self.gcs_bucket = "lc_stylist_agent_outfit_items"
|
||||
self.minio_bucket = "lanecarford"
|
||||
|
||||
def _load_style_guide(self, path: str) -> str:
|
||||
def _load_style_guide(self, path: str):
|
||||
"""加载 markdown 风格指南内容。"""
|
||||
parts = path.split('/', 1)
|
||||
if len(parts) != 2:
|
||||
@@ -63,18 +63,24 @@ class AsyncStylistAgent:
|
||||
|
||||
bucket_name, object_name = parts
|
||||
try:
|
||||
# 1. 获取对象
|
||||
# 获取对象 读取内容
|
||||
response = minio_client.get_object(bucket_name, object_name)
|
||||
|
||||
# 2. 读取内容
|
||||
content_bytes = response.read()
|
||||
|
||||
# 3. 关闭连接
|
||||
response.close()
|
||||
response.release_conn()
|
||||
json_response = minio_client.get_object(bucket_name, object_name.replace('.md', '.json'))
|
||||
json_data = json_response.data
|
||||
|
||||
# 4. 解码并返回
|
||||
return content_bytes.decode('utf-8')
|
||||
# 关闭连接
|
||||
response.close()
|
||||
json_response.close()
|
||||
response.release_conn()
|
||||
json_response.release_conn()
|
||||
|
||||
# 4. 解析 JSON 字符串
|
||||
json_string = json_data.decode('utf-8')
|
||||
json_content = json.loads(json_string)
|
||||
|
||||
return content_bytes.decode('utf-8'), json_content
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to load style guide from {path}: {e}")
|
||||
@@ -214,6 +220,33 @@ class AsyncStylistAgent:
|
||||
# 返回一个停止信号以防止循环继续
|
||||
return json.dumps({"action": "stop", "reason": f"API_ERROR: {str(e)}"})
|
||||
|
||||
async def _merge_images(self, user_id: str):
|
||||
"""
|
||||
实际调用 Gemini API 的函数,接受文本和可选的图片路径列表。
|
||||
|
||||
Args:
|
||||
user_input: 发送给模型的主文本内容。
|
||||
image_paths: 待发送图片的本地路径列表。
|
||||
|
||||
Returns:
|
||||
模型的响应文本(预期为 JSON 字符串)。
|
||||
"""
|
||||
minio_path = ""
|
||||
if self.outfit_items:
|
||||
merged_image = merge_images_to_square(self.outfit_items, max_len=9, add_text=False)
|
||||
image_bytes_io = io.BytesIO()
|
||||
image_format = 'JPEG'
|
||||
|
||||
merged_image.save(image_bytes_io, format=image_format)
|
||||
image_bytes = image_bytes_io.getvalue()
|
||||
|
||||
file_name = uuid.uuid4()
|
||||
blob_name = f"lc_stylist_agent_outfit_items/{user_id}/{file_name}.jpg"
|
||||
responses = oss_upload_image(oss_client=minio_client, bucket=self.minio_bucket, object_name=blob_name, image_bytes=image_bytes)
|
||||
minio_path = f"{responses.bucket_name}/{responses.object_name}"
|
||||
|
||||
return minio_path
|
||||
|
||||
def _parse_gemini_response(self, response_text: str) -> Optional[Dict[str, Any]]:
|
||||
"""安全解析 Gemini 的 JSON 响应。"""
|
||||
try:
|
||||
@@ -260,20 +293,37 @@ class AsyncStylistAgent:
|
||||
print(f"An error occurred during item retrieval: {e}")
|
||||
return None
|
||||
|
||||
async def _get_random_accessories(self):
|
||||
results = self.local_db.random_get_accessories()
|
||||
async def _get_random_accessories(self, stylist):
|
||||
stylist_item = []
|
||||
stylist_item_ids = []
|
||||
for i in stylist:
|
||||
# 1. 根据stylist要求抽取item
|
||||
query_embedding = self.local_db.get_clip_embedding(i['text'], is_image=False)
|
||||
stylist_results = self.local_db.query_local_db(query_embedding, i['category'], n_results=10)
|
||||
stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count'])
|
||||
stylist_item_ids += [item_id['item_id'] for item_id in stylist_item]
|
||||
|
||||
# 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核)
|
||||
best_meta = results['metadatas'][0]
|
||||
return {
|
||||
"item_id": best_meta['item_id'], # 从 metadata 字典中安全获取
|
||||
"category": best_meta['category'],
|
||||
"gpt_description": best_meta['description'],
|
||||
'description': best_meta['description'],
|
||||
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
|
||||
# 这里假设 item_id 就是文件名的一部分
|
||||
"image_path": os.path.join(f"{best_meta['item_id']}.jpg")
|
||||
}
|
||||
accessories_count = 2
|
||||
|
||||
# 2. 在配饰池中过滤掉已经选中的item ,然后抽两个item
|
||||
random_single_ids = random.choices(list(set(self.local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
|
||||
random_items = self.local_db.random_get_accessories(random_single_ids)['metadatas']
|
||||
all_items = stylist_item + random_items
|
||||
|
||||
items_data = []
|
||||
|
||||
for best_meta in all_items:
|
||||
items_data.append({
|
||||
"item_id": best_meta['item_id'], # 从 metadata 字典中安全获取
|
||||
"category": best_meta['category'],
|
||||
"gpt_description": best_meta['description'],
|
||||
'description': best_meta['description'],
|
||||
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
|
||||
# 这里假设 item_id 就是文件名的一部分
|
||||
"image_path": os.path.join(f"{best_meta['item_id']}.jpg")
|
||||
})
|
||||
|
||||
return items_data
|
||||
|
||||
def _build_user_input(self) -> str:
|
||||
"""构建发送给 Gemini 的用户输入,包含已选单品信息。"""
|
||||
@@ -294,7 +344,7 @@ class AsyncStylistAgent:
|
||||
"""主流程控制循环。"""
|
||||
print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---")
|
||||
|
||||
self.style_guide = self._load_style_guide(stylist_path)
|
||||
self.style_guide, self.style_accessories_guide = self._load_style_guide(stylist_path)
|
||||
self.system_prompt = self._build_system_prompt(request_summary, gender)
|
||||
response_data = {"status": "",
|
||||
"message": "",
|
||||
@@ -342,12 +392,13 @@ class AsyncStylistAgent:
|
||||
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
|
||||
logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
|
||||
|
||||
# 新增配饰
|
||||
new_item = await self._get_random_accessories()
|
||||
self.outfit_items.append(new_item)
|
||||
user_input = self._build_user_input()
|
||||
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
|
||||
response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')})
|
||||
# 根据stylist要求随机增加配饰 3-4个配饰
|
||||
new_item = await self._get_random_accessories(self.style_accessories_guide)
|
||||
for item in new_item:
|
||||
self.outfit_items.append(item)
|
||||
response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')})
|
||||
|
||||
response_data['path'] = await self._merge_images(user_id)
|
||||
|
||||
logger.info(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
|
||||
self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided')
|
||||
@@ -424,12 +475,13 @@ class AsyncStylistAgent:
|
||||
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
|
||||
logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
|
||||
|
||||
# 新增配饰
|
||||
new_item = await self._get_random_accessories()
|
||||
self.outfit_items.append(new_item)
|
||||
user_input = self._build_user_input()
|
||||
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
|
||||
response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')})
|
||||
# 根据stylist要求随机增加配饰 3-4个配饰
|
||||
new_item = await self._get_random_accessories(self.style_accessories_guide)
|
||||
for item in new_item:
|
||||
self.outfit_items.append(item)
|
||||
response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')})
|
||||
response_data['path'] = await self._merge_images(user_id)
|
||||
|
||||
logger.info("🚨 达到最大搭配数量限制,强制终止。")
|
||||
self.stop_reason = "Finish reason: Reached max outfit length."
|
||||
response_data['status'] = "stop"
|
||||
|
||||
@@ -69,7 +69,7 @@ class VectorDatabase():
|
||||
"$and": [
|
||||
{"item_group_id": {"$ne": "Clothing"}},
|
||||
{"item_group_id": {"$ne": "Shoes"}}, # 新增:过滤 Shoes
|
||||
{"modality": "image"},
|
||||
{"modality": "image"}
|
||||
]
|
||||
}
|
||||
MAX_LIMIT = 100000
|
||||
@@ -89,12 +89,11 @@ class VectorDatabase():
|
||||
print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
|
||||
return []
|
||||
|
||||
def random_get_accessories(self):
|
||||
random_single_id = random.choice(self.cache_filtered_ids)
|
||||
def random_get_accessories(self, ids):
|
||||
# 2. 调用 ChromaDB:只查询这一个 ID 的详细信息
|
||||
try:
|
||||
final_results = self.collection.get(
|
||||
ids=[random_single_id],
|
||||
ids=ids,
|
||||
include=["metadatas"] # 你只需要元数据
|
||||
)
|
||||
|
||||
@@ -110,5 +109,25 @@ class VectorDatabase():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
|
||||
print(db.random_get_accessories())
|
||||
stylist = {
|
||||
'text': "gold necklace",
|
||||
'count': 2,
|
||||
'category': "Jewelry"
|
||||
}
|
||||
max_len = 5
|
||||
local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
|
||||
# print(db.random_get_accessories())
|
||||
|
||||
query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False)
|
||||
|
||||
results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10)
|
||||
# 2. 从结果集中抽 stylist['count'] 个item
|
||||
stylist_item = random.choices(results['metadatas'][0], k=stylist['count'])
|
||||
stylist_item_ids = [item_id['item_id'] for item_id in stylist_item]
|
||||
|
||||
# 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item
|
||||
accessories_count = 9 - max_len - stylist['count']
|
||||
|
||||
random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
|
||||
random_items = local_db.random_get_accessories(random_single_ids)['metadatas']
|
||||
all_items = stylist_item + random_items
|
||||
|
||||
@@ -114,12 +114,19 @@ if __name__ == '__main__':
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/7fed1c7b-9efd-41fa-a335-182c310ea611.jpg"
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/5de155d0-56a6-43e8-a2f1-7538fce86220.jpg"
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/1cd1803c-5f51-4961-a4f2-2acd3e0d8294.jpg"
|
||||
url = 'lanecarford/lc_stylist_agent_outfit_items/string/99cd8cc0-856a-487d-bb21-5684855ef48f.jpg'
|
||||
url = [
|
||||
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/d9df7c48-c7e5-47f9-be67-07f0d175d202.jpg',
|
||||
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/ddf39b9c-69f0-4b28-95ed-9d823fa82e35.jpg',
|
||||
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/112194a0-dc1d-4151-8c58-82642142a553.jpg',
|
||||
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/788007f1-e44b-4390-ad9e-a2d4ba406379.jpg'
|
||||
]
|
||||
read_type = "1"
|
||||
img = oss_get_image(oss_client=minio_client, path=url, data_type=read_type)
|
||||
if read_type == "cv2":
|
||||
cv2.imshow("", img)
|
||||
cv2.waitKey(0)
|
||||
else:
|
||||
img.show()
|
||||
img.save("4.png")
|
||||
for id, i in enumerate(url):
|
||||
img = oss_get_image(minio_client, i, read_type)
|
||||
img = oss_get_image(oss_client=minio_client, path=i, data_type=read_type)
|
||||
if read_type == "cv2":
|
||||
cv2.imshow("", img)
|
||||
cv2.waitKey(0)
|
||||
else:
|
||||
img.show()
|
||||
img.save(f"{time.time()}.png")
|
||||
|
||||
@@ -16,4 +16,5 @@ pytorch-fid==0.3.0
|
||||
open-clip-torch==2.24.0
|
||||
pytorch-fid==0.3.0
|
||||
litserve
|
||||
# pip install git+https://github.com/openai/CLIP.git
|
||||
# pip install git+https://github.com/openai/CLIP.git
|
||||
# pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||
Reference in New Issue
Block a user