diff --git a/app/server/ChatbotAgent/core/stylist_agent_server.py b/app/server/ChatbotAgent/core/stylist_agent_server.py index 64d2202..38a6f4e 100644 --- a/app/server/ChatbotAgent/core/stylist_agent_server.py +++ b/app/server/ChatbotAgent/core/stylist_agent_server.py @@ -293,7 +293,7 @@ class AsyncStylistAgent: print(f"An error occurred during item retrieval: {e}") return None - async def _get_random_accessories(self, stylist): + async def _get_random_accessories(self, stylist, item_count): stylist_item = [] stylist_item_ids = [] for i in stylist: @@ -303,12 +303,16 @@ class AsyncStylistAgent: stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count']) stylist_item_ids += [item_id['item_id'] for item_id in stylist_item] - accessories_count = 2 + accessories_count = 9 - item_count - len(stylist_item) - # 2. 在配饰池中过滤掉已经选中的item ,然后抽两个item - random_single_ids = random.choices(list(set(self.local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count) - random_items = self.local_db.random_get_accessories(random_single_ids)['metadatas'] - all_items = stylist_item + random_items + if accessories_count > 0: + accessories_count = random.randint(1, accessories_count) + # 2. 在配饰池中过滤掉已经选中的item ,然后抽两个item + random_single_ids = random.choices(list(set(self.local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count) + random_items = self.local_db.random_get_accessories(random_single_ids)['metadatas'] + all_items = stylist_item + random_items + else: + all_items = stylist_item items_data = [] @@ -393,7 +397,7 @@ class AsyncStylistAgent: logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}") # 根据stylist要求随机增加配饰 3-4个配饰 - new_item = await self._get_random_accessories(self.style_accessories_guide) + new_item = await self._get_random_accessories(self.style_accessories_guide, len(self.outfit_items)) for item in new_item: self.outfit_items.append(item) response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')}) @@ -476,7 +480,7 @@ class AsyncStylistAgent: logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}") # 根据stylist要求随机增加配饰 3-4个配饰 - new_item = await self._get_random_accessories(self.style_accessories_guide) + new_item = await self._get_random_accessories(self.style_accessories_guide, len(self.outfit_items)) for item in new_item: self.outfit_items.append(item) response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')})