diff --git a/app/server/ChatbotAgent/core/stylist_agent_server.py b/app/server/ChatbotAgent/core/stylist_agent_server.py index 13d3908..64d2202 100644 --- a/app/server/ChatbotAgent/core/stylist_agent_server.py +++ b/app/server/ChatbotAgent/core/stylist_agent_server.py @@ -55,7 +55,7 @@ class AsyncStylistAgent: self.gcs_bucket = "lc_stylist_agent_outfit_items" self.minio_bucket = "lanecarford" - def _load_style_guide(self, path: str) -> str: + def _load_style_guide(self, path: str): """加载 markdown 风格指南内容。""" parts = path.split('/', 1) if len(parts) != 2: @@ -63,18 +63,24 @@ class AsyncStylistAgent: bucket_name, object_name = parts try: - # 1. 获取对象 + # 获取对象 读取内容 response = minio_client.get_object(bucket_name, object_name) - - # 2. 读取内容 content_bytes = response.read() - # 3. 关闭连接 - response.close() - response.release_conn() + json_response = minio_client.get_object(bucket_name, object_name.replace('.md', '.json')) + json_data = json_response.data - # 4. 解码并返回 - return content_bytes.decode('utf-8') + # 关闭连接 + response.close() + json_response.close() + response.release_conn() + json_response.release_conn() + + # 4. 解析 JSON 字符串 + json_string = json_data.decode('utf-8') + json_content = json.loads(json_string) + + return content_bytes.decode('utf-8'), json_content except Exception as e: raise Exception(f"Failed to load style guide from {path}: {e}") @@ -214,6 +220,33 @@ class AsyncStylistAgent: # 返回一个停止信号以防止循环继续 return json.dumps({"action": "stop", "reason": f"API_ERROR: {str(e)}"}) + async def _merge_images(self, user_id: str): + """ + 实际调用 Gemini API 的函数,接受文本和可选的图片路径列表。 + + Args: + user_input: 发送给模型的主文本内容。 + image_paths: 待发送图片的本地路径列表。 + + Returns: + 模型的响应文本(预期为 JSON 字符串)。 + """ + minio_path = "" + if self.outfit_items: + merged_image = merge_images_to_square(self.outfit_items, max_len=9, add_text=False) + image_bytes_io = io.BytesIO() + image_format = 'JPEG' + + merged_image.save(image_bytes_io, format=image_format) + image_bytes = image_bytes_io.getvalue() + + file_name = uuid.uuid4() + blob_name = f"lc_stylist_agent_outfit_items/{user_id}/{file_name}.jpg" + responses = oss_upload_image(oss_client=minio_client, bucket=self.minio_bucket, object_name=blob_name, image_bytes=image_bytes) + minio_path = f"{responses.bucket_name}/{responses.object_name}" + + return minio_path + def _parse_gemini_response(self, response_text: str) -> Optional[Dict[str, Any]]: """安全解析 Gemini 的 JSON 响应。""" try: @@ -260,20 +293,37 @@ class AsyncStylistAgent: print(f"An error occurred during item retrieval: {e}") return None - async def _get_random_accessories(self): - results = self.local_db.random_get_accessories() + async def _get_random_accessories(self, stylist): + stylist_item = [] + stylist_item_ids = [] + for i in stylist: + # 1. 根据stylist要求抽取item + query_embedding = self.local_db.get_clip_embedding(i['text'], is_image=False) + stylist_results = self.local_db.query_local_db(query_embedding, i['category'], n_results=10) + stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count']) + stylist_item_ids += [item_id['item_id'] for item_id in stylist_item] - # 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核) - best_meta = results['metadatas'][0] - return { - "item_id": best_meta['item_id'], # 从 metadata 字典中安全获取 - "category": best_meta['category'], - "gpt_description": best_meta['description'], - 'description': best_meta['description'], - # 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导 - # 这里假设 item_id 就是文件名的一部分 - "image_path": os.path.join(f"{best_meta['item_id']}.jpg") - } + accessories_count = 2 + + # 2. 在配饰池中过滤掉已经选中的item ,然后抽两个item + random_single_ids = random.choices(list(set(self.local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count) + random_items = self.local_db.random_get_accessories(random_single_ids)['metadatas'] + all_items = stylist_item + random_items + + items_data = [] + + for best_meta in all_items: + items_data.append({ + "item_id": best_meta['item_id'], # 从 metadata 字典中安全获取 + "category": best_meta['category'], + "gpt_description": best_meta['description'], + 'description': best_meta['description'], + # 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导 + # 这里假设 item_id 就是文件名的一部分 + "image_path": os.path.join(f"{best_meta['item_id']}.jpg") + }) + + return items_data def _build_user_input(self) -> str: """构建发送给 Gemini 的用户输入,包含已选单品信息。""" @@ -294,7 +344,7 @@ class AsyncStylistAgent: """主流程控制循环。""" print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---") - self.style_guide = self._load_style_guide(stylist_path) + self.style_guide, self.style_accessories_guide = self._load_style_guide(stylist_path) self.system_prompt = self._build_system_prompt(request_summary, gender) response_data = {"status": "", "message": "", @@ -342,12 +392,13 @@ class AsyncStylistAgent: response = post_request(url=url, data=json.dumps(response_data), headers=headers) logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}") - # 新增配饰 - new_item = await self._get_random_accessories() - self.outfit_items.append(new_item) - user_input = self._build_user_input() - gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id) - response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')}) + # 根据stylist要求随机增加配饰 3-4个配饰 + new_item = await self._get_random_accessories(self.style_accessories_guide) + for item in new_item: + self.outfit_items.append(item) + response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')}) + + response_data['path'] = await self._merge_images(user_id) logger.info(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}") self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided') @@ -424,12 +475,13 @@ class AsyncStylistAgent: response = post_request(url=url, data=json.dumps(response_data), headers=headers) logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}") - # 新增配饰 - new_item = await self._get_random_accessories() - self.outfit_items.append(new_item) - user_input = self._build_user_input() - gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id) - response_data['items'].append({"item_id": new_item.get('item_id'), "category": new_item.get('category')}) + # 根据stylist要求随机增加配饰 3-4个配饰 + new_item = await self._get_random_accessories(self.style_accessories_guide) + for item in new_item: + self.outfit_items.append(item) + response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')}) + response_data['path'] = await self._merge_images(user_id) + logger.info("🚨 达到最大搭配数量限制,强制终止。") self.stop_reason = "Finish reason: Reached max outfit length." response_data['status'] = "stop" diff --git a/app/server/ChatbotAgent/core/vector_database.py b/app/server/ChatbotAgent/core/vector_database.py index 6461b5e..bf74a23 100644 --- a/app/server/ChatbotAgent/core/vector_database.py +++ b/app/server/ChatbotAgent/core/vector_database.py @@ -69,7 +69,7 @@ class VectorDatabase(): "$and": [ {"item_group_id": {"$ne": "Clothing"}}, {"item_group_id": {"$ne": "Shoes"}}, # 新增:过滤 Shoes - {"modality": "image"}, + {"modality": "image"} ] } MAX_LIMIT = 100000 @@ -89,12 +89,11 @@ class VectorDatabase(): print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}") return [] - def random_get_accessories(self): - random_single_id = random.choice(self.cache_filtered_ids) + def random_get_accessories(self, ids): # 2. 调用 ChromaDB:只查询这一个 ID 的详细信息 try: final_results = self.collection.get( - ids=[random_single_id], + ids=ids, include=["metadatas"] # 你只需要元数据 ) @@ -110,5 +109,25 @@ class VectorDatabase(): if __name__ == '__main__': - db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32") - print(db.random_get_accessories()) + stylist = { + 'text': "gold necklace", + 'count': 2, + 'category': "Jewelry" + } + max_len = 5 + local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32") + # print(db.random_get_accessories()) + + query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False) + + results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10) + # 2. 从结果集中抽 stylist['count'] 个item + stylist_item = random.choices(results['metadatas'][0], k=stylist['count']) + stylist_item_ids = [item_id['item_id'] for item_id in stylist_item] + + # 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item + accessories_count = 9 - max_len - stylist['count'] + + random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count) + random_items = local_db.random_get_accessories(random_single_ids)['metadatas'] + all_items = stylist_item + random_items diff --git a/app/server/utils/minio_client.py b/app/server/utils/minio_client.py index c2bfdb6..f2e5c5c 100644 --- a/app/server/utils/minio_client.py +++ b/app/server/utils/minio_client.py @@ -114,12 +114,19 @@ if __name__ == '__main__': # url = "lanecarford/lc_stylist_agent_outfit_items/string/7fed1c7b-9efd-41fa-a335-182c310ea611.jpg" # url = "lanecarford/lc_stylist_agent_outfit_items/string/5de155d0-56a6-43e8-a2f1-7538fce86220.jpg" # url = "lanecarford/lc_stylist_agent_outfit_items/string/1cd1803c-5f51-4961-a4f2-2acd3e0d8294.jpg" - url = 'lanecarford/lc_stylist_agent_outfit_items/string/99cd8cc0-856a-487d-bb21-5684855ef48f.jpg' + url = [ + 'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/d9df7c48-c7e5-47f9-be67-07f0d175d202.jpg', + 'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/ddf39b9c-69f0-4b28-95ed-9d823fa82e35.jpg', + 'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/112194a0-dc1d-4151-8c58-82642142a553.jpg', + 'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/788007f1-e44b-4390-ad9e-a2d4ba406379.jpg' + ] read_type = "1" - img = oss_get_image(oss_client=minio_client, path=url, data_type=read_type) - if read_type == "cv2": - cv2.imshow("", img) - cv2.waitKey(0) - else: - img.show() - img.save("4.png") + for id, i in enumerate(url): + img = oss_get_image(minio_client, i, read_type) + img = oss_get_image(oss_client=minio_client, path=i, data_type=read_type) + if read_type == "cv2": + cv2.imshow("", img) + cv2.waitKey(0) + else: + img.show() + img.save(f"{time.time()}.png") diff --git a/requirements.txt b/requirements.txt index 9a3c816..0e9483a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,5 @@ pytorch-fid==0.3.0 open-clip-torch==2.24.0 pytorch-fid==0.3.0 litserve -# pip install git+https://github.com/openai/CLIP.git \ No newline at end of file +# pip install git+https://github.com/openai/CLIP.git +# pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 \ No newline at end of file