UPDATE: New retrieval algorithm. Fixes the multiple-sunglasses issue. Only select items whose score is positive.
This commit is contained in:
@@ -336,16 +336,16 @@ if __name__ == "__main__":
|
||||
request_data = json.load(f)
|
||||
|
||||
tasks_with_metadata = []
|
||||
for test_content in request_data[0:1]:
|
||||
for test_content in request_data[14:20]:
|
||||
occasions = test_content['occasions']
|
||||
request_summary = test_content['request_summary']
|
||||
for stylist_name in ["vera", "edi"]:
|
||||
for stylist_name in ["edi"]:
|
||||
stylist_agent_kwages['outfit_id'] = test_content['test_case_id'] + "_" + "_".join(occasions) + f"_{stylist_name}"
|
||||
stylist_agent_kwages['stylist_name'] = stylist_name
|
||||
stylist_agent_kwages['gender'] = "female"
|
||||
agent = AsyncStylistAgent(**stylist_agent_kwages)
|
||||
coro = agent.run_iterative_styling(
|
||||
# coro = agent.run_quick_batch_styling(
|
||||
# coro = agent.run_iterative_styling(
|
||||
coro = agent.run_quick_batch_styling(
|
||||
request_summary=request_summary,
|
||||
occasions=occasions,
|
||||
start_outfit=[],
|
||||
|
||||
@@ -178,6 +178,7 @@ class AsyncStylistAgent:
|
||||
results = self.local_db.get_matched_item(
|
||||
query_embedding,
|
||||
category,
|
||||
subcategory,
|
||||
occasions=occasions,
|
||||
batch_sources=batch_sources,
|
||||
gender=gender,
|
||||
@@ -388,6 +389,7 @@ class AsyncStylistAgent:
|
||||
gemini_data = self._parse_gemini_response(gemini_response_text)
|
||||
recommended_items = gemini_data.get('recommended_items', [])
|
||||
reason = gemini_data.get('reason', '')
|
||||
failed_found_item_count = 0
|
||||
|
||||
if not recommended_items or not isinstance(recommended_items, List):
|
||||
print("No recommended item from Gemini, terminating process.")
|
||||
@@ -419,10 +421,21 @@ class AsyncStylistAgent:
|
||||
|
||||
new_item = self._get_next_item(description, category, subcategory, occasions, batch_sources, self.gender)
|
||||
if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
|
||||
failed_found_item_count += 1
|
||||
continue
|
||||
else:
|
||||
self.outfit_items.append(new_item)
|
||||
print(f"Item {idx + 1}: ({subcategory}) {rec_item}, found item: {new_item}")
|
||||
|
||||
# 如果没有找到的item过于多,需要重试
|
||||
if failed_found_item_count / len(recommended_items) > 0.5:
|
||||
self.post_operation(
|
||||
status="failed",
|
||||
message=f"There are {failed_found_item_count} items (total {len(recommended_items)}) are not found in the database",
|
||||
callback_url=url,
|
||||
img_path=merged_image_path
|
||||
)
|
||||
print(f"There are {failed_found_item_count} items (total {len(recommended_items)}) are not found in the database")
|
||||
return reason
|
||||
|
||||
async def run_iterative_styling(self, request_summary, occasions, start_outfit: Optional[List] = None, batch_sources: List = [], user_id="test", callback_url=""):
|
||||
|
||||
@@ -46,7 +46,7 @@ class VectorDatabase():
|
||||
|
||||
return features.cpu().numpy().flatten().tolist()
|
||||
|
||||
def get_matched_item(self, embedding: List[float], category: str, occasions: List[str] = [], batch_sources: List[str] = [], gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]:
|
||||
def get_matched_item(self, embedding: List[float], category: str, subcategory: str, occasions: List[str] = [], batch_sources: List[str] = [], gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]:
|
||||
if category not in CATEGORY_LIST:
|
||||
raise ValueError(f"Recommended {category} is not valid.")
|
||||
|
||||
@@ -86,10 +86,14 @@ class VectorDatabase():
|
||||
return []
|
||||
|
||||
metadatas = results['metadatas'][0] # List[Dict[str, Any]]
|
||||
final_scores = []
|
||||
candidate_pool = []
|
||||
all_scores = []
|
||||
for idx, metadata in enumerate(metadatas):
|
||||
dist_img = results['distances'][0][idx]
|
||||
score_vec = 1 - dist_img # cosine similarity range: [-1, 1]
|
||||
score_subcategory = 0.0
|
||||
if subcategory == metadata['subcategory']:
|
||||
score_subcategory = 1
|
||||
|
||||
score_occ = 0.0
|
||||
if occasions:
|
||||
@@ -100,26 +104,42 @@ class VectorDatabase():
|
||||
count += 1
|
||||
status_val = metadata.get(occ, -1)
|
||||
if status_val == 1:
|
||||
score_occ += 1.0
|
||||
score_occ += 5.0
|
||||
elif status_val == 0:
|
||||
score_occ += 0.0
|
||||
else:
|
||||
score_occ -= 100.0
|
||||
score_occ -= 5.0
|
||||
|
||||
score_occ = score_occ / count if count else 0.0
|
||||
|
||||
final_score = 0.6 * score_vec + 0.4 * score_occ
|
||||
final_scores.append(final_score)
|
||||
final_score = 0.5 * score_vec + 0.1 * score_occ + 0.5 * score_subcategory
|
||||
all_scores.append(final_score)
|
||||
if final_score > 0:
|
||||
candidate_pool.append({
|
||||
"score": final_score,
|
||||
"metadata": metadata
|
||||
})
|
||||
|
||||
scores_arr = np.array(final_scores)
|
||||
temperature = 0.5
|
||||
scores_arr = scores_arr / temperature
|
||||
if not candidate_pool:
|
||||
print(f"All scores are negative: Ignore")
|
||||
return []
|
||||
|
||||
# 采取topk截断
|
||||
candidate_pool.sort(key=lambda x: x['score'], reverse=True)
|
||||
current_k = min(10, len(candidate_pool))
|
||||
top_candidates = candidate_pool[:current_k]
|
||||
top_scores = np.array([x['score'] for x in top_candidates])
|
||||
|
||||
#降低温度,使得选择稳定
|
||||
temperature = 0.2
|
||||
scaled_scores = top_scores / temperature
|
||||
|
||||
# Softmax: 将分数转换为概率
|
||||
exp_scores = np.exp(scores_arr - np.max(scores_arr))
|
||||
exp_scores = np.exp(scaled_scores - np.max(scaled_scores))
|
||||
probabilities = exp_scores / np.sum(exp_scores)
|
||||
|
||||
# 采样 (或直接取 Top 1)
|
||||
sampled_index = np.random.choice(a=len(results['ids'][0]), p=probabilities, size=n_results, replace=False) # 不重复采样
|
||||
sampled_items = [metadatas[i] for i in sampled_index]
|
||||
sampled_index = np.random.choice(a=current_k, p=probabilities, size=min(n_results, current_k), replace=False) # 不重复采样
|
||||
sampled_items = [top_candidates[i]['metadata'] for i in sampled_index]
|
||||
|
||||
return sampled_items
|
||||
|
||||
Reference in New Issue
Block a user