diff --git a/app/server/ChatbotAgent/agent_server.py b/app/server/ChatbotAgent/agent_server.py index e157013..75e8340 100644 --- a/app/server/ChatbotAgent/agent_server.py +++ b/app/server/ChatbotAgent/agent_server.py @@ -336,16 +336,16 @@ if __name__ == "__main__": request_data = json.load(f) tasks_with_metadata = [] - for test_content in request_data[0:1]: + for test_content in request_data[14:20]: occasions = test_content['occasions'] request_summary = test_content['request_summary'] - for stylist_name in ["vera", "edi"]: + for stylist_name in ["edi"]: stylist_agent_kwages['outfit_id'] = test_content['test_case_id'] + "_" + "_".join(occasions) + f"_{stylist_name}" stylist_agent_kwages['stylist_name'] = stylist_name stylist_agent_kwages['gender'] = "female" agent = AsyncStylistAgent(**stylist_agent_kwages) - coro = agent.run_iterative_styling( - # coro = agent.run_quick_batch_styling( + # coro = agent.run_iterative_styling( + coro = agent.run_quick_batch_styling( request_summary=request_summary, occasions=occasions, start_outfit=[], diff --git a/app/server/ChatbotAgent/core/stylist_agent_server.py b/app/server/ChatbotAgent/core/stylist_agent_server.py index 57222a6..8c10cbc 100644 --- a/app/server/ChatbotAgent/core/stylist_agent_server.py +++ b/app/server/ChatbotAgent/core/stylist_agent_server.py @@ -178,6 +178,7 @@ class AsyncStylistAgent: results = self.local_db.get_matched_item( query_embedding, category, + subcategory, occasions=occasions, batch_sources=batch_sources, gender=gender, @@ -388,6 +389,7 @@ class AsyncStylistAgent: gemini_data = self._parse_gemini_response(gemini_response_text) recommended_items = gemini_data.get('recommended_items', []) reason = gemini_data.get('reason', '') + failed_found_item_count = 0 if not recommended_items or not isinstance(recommended_items, List): print("No recommended item from Gemini, terminating process.") @@ -419,10 +421,21 @@ class AsyncStylistAgent: new_item = self._get_next_item(description, category, subcategory, occasions, batch_sources, self.gender) if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]: + failed_found_item_count += 1 continue else: self.outfit_items.append(new_item) print(f"Item {idx + 1}: ({subcategory}) {rec_item}, found item: {new_item}") + + # 如果没有找到的item过于多,需要重试 + if failed_found_item_count / len(recommended_items) > 0.5: + self.post_operation( + status="failed", + message=f"There are {failed_found_item_count} items (total {len(recommended_items)}) are not found in the database", + callback_url=url, + img_path=merged_image_path + ) + print(f"There are {failed_found_item_count} items (total {len(recommended_items)}) are not found in the database") return reason async def run_iterative_styling(self, request_summary, occasions, start_outfit: Optional[List] = None, batch_sources: List = [], user_id="test", callback_url=""): diff --git a/app/server/ChatbotAgent/core/vector_database.py b/app/server/ChatbotAgent/core/vector_database.py index c051d99..4466afc 100644 --- a/app/server/ChatbotAgent/core/vector_database.py +++ b/app/server/ChatbotAgent/core/vector_database.py @@ -46,7 +46,7 @@ class VectorDatabase(): return features.cpu().numpy().flatten().tolist() - def get_matched_item(self, embedding: List[float], category: str, occasions: List[str] = [], batch_sources: List[str] = [], gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]: + def get_matched_item(self, embedding: List[float], category: str, subcategory: str, occasions: List[str] = [], batch_sources: List[str] = [], gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]: if category not in CATEGORY_LIST: raise ValueError(f"Recommended {category} is not valid.") @@ -86,10 +86,14 @@ class VectorDatabase(): return [] metadatas = results['metadatas'][0] # List[Dict[str, Any]] - final_scores = [] + candidate_pool = [] + all_scores = [] for idx, metadata in enumerate(metadatas): dist_img = results['distances'][0][idx] score_vec = 1 - dist_img # cosine similarity range: [-1, 1] + score_subcategory = 0.0 + if subcategory == metadata['subcategory']: + score_subcategory = 1 score_occ = 0.0 if occasions: @@ -100,26 +104,42 @@ class VectorDatabase(): count += 1 status_val = metadata.get(occ, -1) if status_val == 1: - score_occ += 1.0 + score_occ += 5.0 elif status_val == 0: score_occ += 0.0 else: - score_occ -= 100.0 + score_occ -= 5.0 score_occ = score_occ / count if count else 0.0 - final_score = 0.6 * score_vec + 0.4 * score_occ - final_scores.append(final_score) + final_score = 0.5 * score_vec + 0.1 * score_occ + 0.5 * score_subcategory + all_scores.append(final_score) + if final_score > 0: + candidate_pool.append({ + "score": final_score, + "metadata": metadata + }) - scores_arr = np.array(final_scores) - temperature = 0.5 - scores_arr = scores_arr / temperature + if not candidate_pool: + print(f"All scores are negative: Ignore") + return [] + + # 采取topk截断 + candidate_pool.sort(key=lambda x: x['score'], reverse=True) + current_k = min(10, len(candidate_pool)) + top_candidates = candidate_pool[:current_k] + top_scores = np.array([x['score'] for x in top_candidates]) + + #降低温度,使得选择稳定 + temperature = 0.2 + scaled_scores = top_scores / temperature # Softmax: 将分数转换为概率 - exp_scores = np.exp(scores_arr - np.max(scores_arr)) + exp_scores = np.exp(scaled_scores - np.max(scaled_scores)) probabilities = exp_scores / np.sum(exp_scores) # 采样 (或直接取 Top 1) - sampled_index = np.random.choice(a=len(results['ids'][0]), p=probabilities, size=n_results, replace=False) # 不重复采样 - sampled_items = [metadatas[i] for i in sampled_index] + sampled_index = np.random.choice(a=current_k, p=probabilities, size=min(n_results, current_k), replace=False) # 不重复采样 + sampled_items = [top_candidates[i]['metadata'] for i in sampled_index] + return sampled_items