新增 取消agent配饰(保留鞋子)推荐,改为默认随机配饰搭配 使用json文件补充stylist删除掉的必要配饰
This commit is contained in:
@@ -55,7 +55,7 @@ class AsyncStylistAgent:
|
|||||||
self.gcs_bucket = "lc_stylist_agent_outfit_items"
|
self.gcs_bucket = "lc_stylist_agent_outfit_items"
|
||||||
self.minio_bucket = "lanecarford"
|
self.minio_bucket = "lanecarford"
|
||||||
|
|
||||||
def _load_style_guide(self, path: str) -> str:
|
def _load_style_guide(self, path: str):
|
||||||
"""加载 markdown 风格指南内容。"""
|
"""加载 markdown 风格指南内容。"""
|
||||||
parts = path.split('/', 1)
|
parts = path.split('/', 1)
|
||||||
if len(parts) != 2:
|
if len(parts) != 2:
|
||||||
@@ -63,18 +63,24 @@ class AsyncStylistAgent:
|
|||||||
|
|
||||||
bucket_name, object_name = parts
|
bucket_name, object_name = parts
|
||||||
try:
|
try:
|
||||||
# 1. 获取对象
|
# 获取对象 读取内容
|
||||||
response = minio_client.get_object(bucket_name, object_name)
|
response = minio_client.get_object(bucket_name, object_name)
|
||||||
|
|
||||||
# 2. 读取内容
|
|
||||||
content_bytes = response.read()
|
content_bytes = response.read()
|
||||||
|
|
||||||
# 3. 关闭连接
|
json_response = minio_client.get_object(bucket_name, object_name.replace('.md', '.json'))
|
||||||
response.close()
|
json_data = json_response.data
|
||||||
response.release_conn()
|
|
||||||
|
|
||||||
# 4. 解码并返回
|
# 关闭连接
|
||||||
return content_bytes.decode('utf-8')
|
response.close()
|
||||||
|
json_response.close()
|
||||||
|
response.release_conn()
|
||||||
|
json_response.release_conn()
|
||||||
|
|
||||||
|
# 4. 解析 JSON 字符串
|
||||||
|
json_string = json_data.decode('utf-8')
|
||||||
|
json_content = json.loads(json_string)
|
||||||
|
|
||||||
|
return content_bytes.decode('utf-8'), json_content
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Failed to load style guide from {path}: {e}")
|
raise Exception(f"Failed to load style guide from {path}: {e}")
|
||||||
@@ -214,6 +220,33 @@ class AsyncStylistAgent:
|
|||||||
# 返回一个停止信号以防止循环继续
|
# 返回一个停止信号以防止循环继续
|
||||||
return json.dumps({"action": "stop", "reason": f"API_ERROR: {str(e)}"})
|
return json.dumps({"action": "stop", "reason": f"API_ERROR: {str(e)}"})
|
||||||
|
|
||||||
|
async def _merge_images(self, user_id: str):
|
||||||
|
"""
|
||||||
|
实际调用 Gemini API 的函数,接受文本和可选的图片路径列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_input: 发送给模型的主文本内容。
|
||||||
|
image_paths: 待发送图片的本地路径列表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模型的响应文本(预期为 JSON 字符串)。
|
||||||
|
"""
|
||||||
|
minio_path = ""
|
||||||
|
if self.outfit_items:
|
||||||
|
merged_image = merge_images_to_square(self.outfit_items, max_len=9, add_text=False)
|
||||||
|
image_bytes_io = io.BytesIO()
|
||||||
|
image_format = 'JPEG'
|
||||||
|
|
||||||
|
merged_image.save(image_bytes_io, format=image_format)
|
||||||
|
image_bytes = image_bytes_io.getvalue()
|
||||||
|
|
||||||
|
file_name = uuid.uuid4()
|
||||||
|
blob_name = f"lc_stylist_agent_outfit_items/{user_id}/{file_name}.jpg"
|
||||||
|
responses = oss_upload_image(oss_client=minio_client, bucket=self.minio_bucket, object_name=blob_name, image_bytes=image_bytes)
|
||||||
|
minio_path = f"{responses.bucket_name}/{responses.object_name}"
|
||||||
|
|
||||||
|
return minio_path
|
||||||
|
|
||||||
def _parse_gemini_response(self, response_text: str) -> Optional[Dict[str, Any]]:
|
def _parse_gemini_response(self, response_text: str) -> Optional[Dict[str, Any]]:
|
||||||
"""安全解析 Gemini 的 JSON 响应。"""
|
"""安全解析 Gemini 的 JSON 响应。"""
|
||||||
try:
|
try:
|
||||||
@@ -260,12 +293,27 @@ class AsyncStylistAgent:
|
|||||||
print(f"An error occurred during item retrieval: {e}")
|
print(f"An error occurred during item retrieval: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _get_random_accessories(self):
|
async def _get_random_accessories(self, stylist):
|
||||||
results = self.local_db.random_get_accessories()
|
stylist_item = []
|
||||||
|
stylist_item_ids = []
|
||||||
|
for i in stylist:
|
||||||
|
# 1. 根据stylist要求抽取item
|
||||||
|
query_embedding = self.local_db.get_clip_embedding(i['text'], is_image=False)
|
||||||
|
stylist_results = self.local_db.query_local_db(query_embedding, i['category'], n_results=10)
|
||||||
|
stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count'])
|
||||||
|
stylist_item_ids += [item_id['item_id'] for item_id in stylist_item]
|
||||||
|
|
||||||
# 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核)
|
accessories_count = 2
|
||||||
best_meta = results['metadatas'][0]
|
|
||||||
return {
|
# 2. 在配饰池中过滤掉已经选中的item ,然后抽两个item
|
||||||
|
random_single_ids = random.choices(list(set(self.local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
|
||||||
|
random_items = self.local_db.random_get_accessories(random_single_ids)['metadatas']
|
||||||
|
all_items = stylist_item + random_items
|
||||||
|
|
||||||
|
items_data = []
|
||||||
|
|
||||||
|
for best_meta in all_items:
|
||||||
|
items_data.append({
|
||||||
"item_id": best_meta['item_id'], # 从 metadata 字典中安全获取
|
"item_id": best_meta['item_id'], # 从 metadata 字典中安全获取
|
||||||
"category": best_meta['category'],
|
"category": best_meta['category'],
|
||||||
"gpt_description": best_meta['description'],
|
"gpt_description": best_meta['description'],
|
||||||
@@ -273,7 +321,9 @@ class AsyncStylistAgent:
|
|||||||
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
|
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
|
||||||
# 这里假设 item_id 就是文件名的一部分
|
# 这里假设 item_id 就是文件名的一部分
|
||||||
"image_path": os.path.join(f"{best_meta['item_id']}.jpg")
|
"image_path": os.path.join(f"{best_meta['item_id']}.jpg")
|
||||||
}
|
})
|
||||||
|
|
||||||
|
return items_data
|
||||||
|
|
||||||
def _build_user_input(self) -> str:
|
def _build_user_input(self) -> str:
|
||||||
"""构建发送给 Gemini 的用户输入,包含已选单品信息。"""
|
"""构建发送给 Gemini 的用户输入,包含已选单品信息。"""
|
||||||
@@ -294,7 +344,7 @@ class AsyncStylistAgent:
|
|||||||
"""主流程控制循环。"""
|
"""主流程控制循环。"""
|
||||||
print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---")
|
print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---")
|
||||||
|
|
||||||
self.style_guide = self._load_style_guide(stylist_path)
|
self.style_guide, self.style_accessories_guide = self._load_style_guide(stylist_path)
|
||||||
self.system_prompt = self._build_system_prompt(request_summary, gender)
|
self.system_prompt = self._build_system_prompt(request_summary, gender)
|
||||||
response_data = {"status": "",
|
response_data = {"status": "",
|
||||||
"message": "",
|
"message": "",
|
||||||
@@ -342,12 +392,13 @@ class AsyncStylistAgent:
|
|||||||
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
|
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
|
||||||
logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
|
logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
|
||||||
|
|
||||||
# 新增配饰
|
# 根据stylist要求随机增加配饰 3-4个配饰
|
||||||
new_item = await self._get_random_accessories()
|
new_item = await self._get_random_accessories(self.style_accessories_guide)
|
||||||
self.outfit_items.append(new_item)
|
for item in new_item:
|
||||||
user_input = self._build_user_input()
|
self.outfit_items.append(item)
|
||||||
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
|
response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')})
|
||||||
response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')})
|
|
||||||
|
response_data['path'] = await self._merge_images(user_id)
|
||||||
|
|
||||||
logger.info(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
|
logger.info(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
|
||||||
self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided')
|
self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided')
|
||||||
@@ -424,12 +475,13 @@ class AsyncStylistAgent:
|
|||||||
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
|
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
|
||||||
logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
|
logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
|
||||||
|
|
||||||
# 新增配饰
|
# 根据stylist要求随机增加配饰 3-4个配饰
|
||||||
new_item = await self._get_random_accessories()
|
new_item = await self._get_random_accessories(self.style_accessories_guide)
|
||||||
self.outfit_items.append(new_item)
|
for item in new_item:
|
||||||
user_input = self._build_user_input()
|
self.outfit_items.append(item)
|
||||||
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
|
response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')})
|
||||||
response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')})
|
response_data['path'] = await self._merge_images(user_id)
|
||||||
|
|
||||||
logger.info("🚨 达到最大搭配数量限制,强制终止。")
|
logger.info("🚨 达到最大搭配数量限制,强制终止。")
|
||||||
self.stop_reason = "Finish reason: Reached max outfit length."
|
self.stop_reason = "Finish reason: Reached max outfit length."
|
||||||
response_data['status'] = "stop"
|
response_data['status'] = "stop"
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class VectorDatabase():
|
|||||||
"$and": [
|
"$and": [
|
||||||
{"item_group_id": {"$ne": "Clothing"}},
|
{"item_group_id": {"$ne": "Clothing"}},
|
||||||
{"item_group_id": {"$ne": "Shoes"}}, # 新增:过滤 Shoes
|
{"item_group_id": {"$ne": "Shoes"}}, # 新增:过滤 Shoes
|
||||||
{"modality": "image"},
|
{"modality": "image"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
MAX_LIMIT = 100000
|
MAX_LIMIT = 100000
|
||||||
@@ -89,12 +89,11 @@ class VectorDatabase():
|
|||||||
print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
|
print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def random_get_accessories(self):
|
def random_get_accessories(self, ids):
|
||||||
random_single_id = random.choice(self.cache_filtered_ids)
|
|
||||||
# 2. 调用 ChromaDB:只查询这一个 ID 的详细信息
|
# 2. 调用 ChromaDB:只查询这一个 ID 的详细信息
|
||||||
try:
|
try:
|
||||||
final_results = self.collection.get(
|
final_results = self.collection.get(
|
||||||
ids=[random_single_id],
|
ids=ids,
|
||||||
include=["metadatas"] # 你只需要元数据
|
include=["metadatas"] # 你只需要元数据
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -110,5 +109,25 @@ class VectorDatabase():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
|
stylist = {
|
||||||
print(db.random_get_accessories())
|
'text': "gold necklace",
|
||||||
|
'count': 2,
|
||||||
|
'category': "Jewelry"
|
||||||
|
}
|
||||||
|
max_len = 5
|
||||||
|
local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
|
||||||
|
# print(db.random_get_accessories())
|
||||||
|
|
||||||
|
query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False)
|
||||||
|
|
||||||
|
results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10)
|
||||||
|
# 2. 从结果集中抽 stylist['count'] 个item
|
||||||
|
stylist_item = random.choices(results['metadatas'][0], k=stylist['count'])
|
||||||
|
stylist_item_ids = [item_id['item_id'] for item_id in stylist_item]
|
||||||
|
|
||||||
|
# 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item
|
||||||
|
accessories_count = 9 - max_len - stylist['count']
|
||||||
|
|
||||||
|
random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
|
||||||
|
random_items = local_db.random_get_accessories(random_single_ids)['metadatas']
|
||||||
|
all_items = stylist_item + random_items
|
||||||
|
|||||||
@@ -114,12 +114,19 @@ if __name__ == '__main__':
|
|||||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/7fed1c7b-9efd-41fa-a335-182c310ea611.jpg"
|
# url = "lanecarford/lc_stylist_agent_outfit_items/string/7fed1c7b-9efd-41fa-a335-182c310ea611.jpg"
|
||||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/5de155d0-56a6-43e8-a2f1-7538fce86220.jpg"
|
# url = "lanecarford/lc_stylist_agent_outfit_items/string/5de155d0-56a6-43e8-a2f1-7538fce86220.jpg"
|
||||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/1cd1803c-5f51-4961-a4f2-2acd3e0d8294.jpg"
|
# url = "lanecarford/lc_stylist_agent_outfit_items/string/1cd1803c-5f51-4961-a4f2-2acd3e0d8294.jpg"
|
||||||
url = 'lanecarford/lc_stylist_agent_outfit_items/string/99cd8cc0-856a-487d-bb21-5684855ef48f.jpg'
|
url = [
|
||||||
|
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/d9df7c48-c7e5-47f9-be67-07f0d175d202.jpg',
|
||||||
|
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/ddf39b9c-69f0-4b28-95ed-9d823fa82e35.jpg',
|
||||||
|
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/112194a0-dc1d-4151-8c58-82642142a553.jpg',
|
||||||
|
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/788007f1-e44b-4390-ad9e-a2d4ba406379.jpg'
|
||||||
|
]
|
||||||
read_type = "1"
|
read_type = "1"
|
||||||
img = oss_get_image(oss_client=minio_client, path=url, data_type=read_type)
|
for id, i in enumerate(url):
|
||||||
|
img = oss_get_image(minio_client, i, read_type)
|
||||||
|
img = oss_get_image(oss_client=minio_client, path=i, data_type=read_type)
|
||||||
if read_type == "cv2":
|
if read_type == "cv2":
|
||||||
cv2.imshow("", img)
|
cv2.imshow("", img)
|
||||||
cv2.waitKey(0)
|
cv2.waitKey(0)
|
||||||
else:
|
else:
|
||||||
img.show()
|
img.show()
|
||||||
img.save("4.png")
|
img.save(f"{time.time()}.png")
|
||||||
|
|||||||
@@ -17,3 +17,4 @@ open-clip-torch==2.24.0
|
|||||||
pytorch-fid==0.3.0
|
pytorch-fid==0.3.0
|
||||||
litserve
|
litserve
|
||||||
# pip install git+https://github.com/openai/CLIP.git
|
# pip install git+https://github.com/openai/CLIP.git
|
||||||
|
# pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||||
Reference in New Issue
Block a user