UPDATE: New retrieval algorithm. Solves the multiple-sunglasses issue. Only selects items with a positive score.

This commit is contained in:
pangkaicheng
2025-12-19 17:01:44 +08:00
parent 884e7966dd
commit 54aac900ad
3 changed files with 49 additions and 16 deletions

View File

@@ -336,16 +336,16 @@ if __name__ == "__main__":
request_data = json.load(f)
tasks_with_metadata = []
for test_content in request_data[0:1]:
for test_content in request_data[14:20]:
occasions = test_content['occasions']
request_summary = test_content['request_summary']
for stylist_name in ["vera", "edi"]:
for stylist_name in ["edi"]:
stylist_agent_kwages['outfit_id'] = test_content['test_case_id'] + "_" + "_".join(occasions) + f"_{stylist_name}"
stylist_agent_kwages['stylist_name'] = stylist_name
stylist_agent_kwages['gender'] = "female"
agent = AsyncStylistAgent(**stylist_agent_kwages)
coro = agent.run_iterative_styling(
# coro = agent.run_quick_batch_styling(
# coro = agent.run_iterative_styling(
coro = agent.run_quick_batch_styling(
request_summary=request_summary,
occasions=occasions,
start_outfit=[],

View File

@@ -178,6 +178,7 @@ class AsyncStylistAgent:
results = self.local_db.get_matched_item(
query_embedding,
category,
subcategory,
occasions=occasions,
batch_sources=batch_sources,
gender=gender,
@@ -388,6 +389,7 @@ class AsyncStylistAgent:
gemini_data = self._parse_gemini_response(gemini_response_text)
recommended_items = gemini_data.get('recommended_items', [])
reason = gemini_data.get('reason', '')
failed_found_item_count = 0
if not recommended_items or not isinstance(recommended_items, List):
print("No recommended item from Gemini, terminating process.")
@@ -419,10 +421,21 @@ class AsyncStylistAgent:
new_item = self._get_next_item(description, category, subcategory, occasions, batch_sources, self.gender)
if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
failed_found_item_count += 1
continue
else:
self.outfit_items.append(new_item)
print(f"Item {idx + 1}: ({subcategory}) {rec_item}, found item: {new_item}")
# 如果没有找到的item过于多需要重试
if failed_found_item_count / len(recommended_items) > 0.5:
self.post_operation(
status="failed",
message=f"There are {failed_found_item_count} items (total {len(recommended_items)}) are not found in the database",
callback_url=url,
img_path=merged_image_path
)
print(f"There are {failed_found_item_count} items (total {len(recommended_items)}) are not found in the database")
return reason
async def run_iterative_styling(self, request_summary, occasions, start_outfit: Optional[List] = None, batch_sources: List = [], user_id="test", callback_url=""):

View File

@@ -46,7 +46,7 @@ class VectorDatabase():
return features.cpu().numpy().flatten().tolist()
def get_matched_item(self, embedding: List[float], category: str, occasions: List[str] = [], batch_sources: List[str] = [], gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]:
def get_matched_item(self, embedding: List[float], category: str, subcategory: str, occasions: List[str] = [], batch_sources: List[str] = [], gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]:
if category not in CATEGORY_LIST:
raise ValueError(f"Recommended {category} is not valid.")
@@ -86,10 +86,14 @@ class VectorDatabase():
return []
metadatas = results['metadatas'][0] # List[Dict[str, Any]]
final_scores = []
candidate_pool = []
all_scores = []
for idx, metadata in enumerate(metadatas):
dist_img = results['distances'][0][idx]
score_vec = 1 - dist_img # cosine similarity range: [-1, 1]
score_subcategory = 0.0
if subcategory == metadata['subcategory']:
score_subcategory = 1
score_occ = 0.0
if occasions:
@@ -100,26 +104,42 @@ class VectorDatabase():
count += 1
status_val = metadata.get(occ, -1)
if status_val == 1:
score_occ += 1.0
score_occ += 5.0
elif status_val == 0:
score_occ += 0.0
else:
score_occ -= 100.0
score_occ -= 5.0
score_occ = score_occ / count if count else 0.0
final_score = 0.6 * score_vec + 0.4 * score_occ
final_scores.append(final_score)
final_score = 0.5 * score_vec + 0.1 * score_occ + 0.5 * score_subcategory
all_scores.append(final_score)
if final_score > 0:
candidate_pool.append({
"score": final_score,
"metadata": metadata
})
scores_arr = np.array(final_scores)
temperature = 0.5
scores_arr = scores_arr / temperature
if not candidate_pool:
print(f"All scores are negative: Ignore")
return []
# 采取topk截断
candidate_pool.sort(key=lambda x: x['score'], reverse=True)
current_k = min(10, len(candidate_pool))
top_candidates = candidate_pool[:current_k]
top_scores = np.array([x['score'] for x in top_candidates])
#降低温度,使得选择稳定
temperature = 0.2
scaled_scores = top_scores / temperature
# Softmax: 将分数转换为概率
exp_scores = np.exp(scores_arr - np.max(scores_arr))
exp_scores = np.exp(scaled_scores - np.max(scaled_scores))
probabilities = exp_scores / np.sum(exp_scores)
# 采样 (或直接取 Top 1)
sampled_index = np.random.choice(a=len(results['ids'][0]), p=probabilities, size=n_results, replace=False) # 不重复采样
sampled_items = [metadatas[i] for i in sampled_index]
sampled_index = np.random.choice(a=current_k, p=probabilities, size=min(n_results, current_k), replace=False) # 不重复采样
sampled_items = [top_candidates[i]['metadata'] for i in sampled_index]
return sampled_items