UPDATE: New retrieve algorithm. Solve multiple sungalsses issue. Only find item with score is positive.
This commit is contained in:
@@ -336,16 +336,16 @@ if __name__ == "__main__":
|
|||||||
request_data = json.load(f)
|
request_data = json.load(f)
|
||||||
|
|
||||||
tasks_with_metadata = []
|
tasks_with_metadata = []
|
||||||
for test_content in request_data[0:1]:
|
for test_content in request_data[14:20]:
|
||||||
occasions = test_content['occasions']
|
occasions = test_content['occasions']
|
||||||
request_summary = test_content['request_summary']
|
request_summary = test_content['request_summary']
|
||||||
for stylist_name in ["vera", "edi"]:
|
for stylist_name in ["edi"]:
|
||||||
stylist_agent_kwages['outfit_id'] = test_content['test_case_id'] + "_" + "_".join(occasions) + f"_{stylist_name}"
|
stylist_agent_kwages['outfit_id'] = test_content['test_case_id'] + "_" + "_".join(occasions) + f"_{stylist_name}"
|
||||||
stylist_agent_kwages['stylist_name'] = stylist_name
|
stylist_agent_kwages['stylist_name'] = stylist_name
|
||||||
stylist_agent_kwages['gender'] = "female"
|
stylist_agent_kwages['gender'] = "female"
|
||||||
agent = AsyncStylistAgent(**stylist_agent_kwages)
|
agent = AsyncStylistAgent(**stylist_agent_kwages)
|
||||||
coro = agent.run_iterative_styling(
|
# coro = agent.run_iterative_styling(
|
||||||
# coro = agent.run_quick_batch_styling(
|
coro = agent.run_quick_batch_styling(
|
||||||
request_summary=request_summary,
|
request_summary=request_summary,
|
||||||
occasions=occasions,
|
occasions=occasions,
|
||||||
start_outfit=[],
|
start_outfit=[],
|
||||||
|
|||||||
@@ -178,6 +178,7 @@ class AsyncStylistAgent:
|
|||||||
results = self.local_db.get_matched_item(
|
results = self.local_db.get_matched_item(
|
||||||
query_embedding,
|
query_embedding,
|
||||||
category,
|
category,
|
||||||
|
subcategory,
|
||||||
occasions=occasions,
|
occasions=occasions,
|
||||||
batch_sources=batch_sources,
|
batch_sources=batch_sources,
|
||||||
gender=gender,
|
gender=gender,
|
||||||
@@ -388,6 +389,7 @@ class AsyncStylistAgent:
|
|||||||
gemini_data = self._parse_gemini_response(gemini_response_text)
|
gemini_data = self._parse_gemini_response(gemini_response_text)
|
||||||
recommended_items = gemini_data.get('recommended_items', [])
|
recommended_items = gemini_data.get('recommended_items', [])
|
||||||
reason = gemini_data.get('reason', '')
|
reason = gemini_data.get('reason', '')
|
||||||
|
failed_found_item_count = 0
|
||||||
|
|
||||||
if not recommended_items or not isinstance(recommended_items, List):
|
if not recommended_items or not isinstance(recommended_items, List):
|
||||||
print("No recommended item from Gemini, terminating process.")
|
print("No recommended item from Gemini, terminating process.")
|
||||||
@@ -419,10 +421,21 @@ class AsyncStylistAgent:
|
|||||||
|
|
||||||
new_item = self._get_next_item(description, category, subcategory, occasions, batch_sources, self.gender)
|
new_item = self._get_next_item(description, category, subcategory, occasions, batch_sources, self.gender)
|
||||||
if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
|
if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
|
||||||
|
failed_found_item_count += 1
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
self.outfit_items.append(new_item)
|
self.outfit_items.append(new_item)
|
||||||
print(f"Item {idx + 1}: ({subcategory}) {rec_item}, found item: {new_item}")
|
print(f"Item {idx + 1}: ({subcategory}) {rec_item}, found item: {new_item}")
|
||||||
|
|
||||||
|
# 如果没有找到的item过于多,需要重试
|
||||||
|
if failed_found_item_count / len(recommended_items) > 0.5:
|
||||||
|
self.post_operation(
|
||||||
|
status="failed",
|
||||||
|
message=f"There are {failed_found_item_count} items (total {len(recommended_items)}) are not found in the database",
|
||||||
|
callback_url=url,
|
||||||
|
img_path=merged_image_path
|
||||||
|
)
|
||||||
|
print(f"There are {failed_found_item_count} items (total {len(recommended_items)}) are not found in the database")
|
||||||
return reason
|
return reason
|
||||||
|
|
||||||
async def run_iterative_styling(self, request_summary, occasions, start_outfit: Optional[List] = None, batch_sources: List = [], user_id="test", callback_url=""):
|
async def run_iterative_styling(self, request_summary, occasions, start_outfit: Optional[List] = None, batch_sources: List = [], user_id="test", callback_url=""):
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class VectorDatabase():
|
|||||||
|
|
||||||
return features.cpu().numpy().flatten().tolist()
|
return features.cpu().numpy().flatten().tolist()
|
||||||
|
|
||||||
def get_matched_item(self, embedding: List[float], category: str, occasions: List[str] = [], batch_sources: List[str] = [], gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]:
|
def get_matched_item(self, embedding: List[float], category: str, subcategory: str, occasions: List[str] = [], batch_sources: List[str] = [], gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]:
|
||||||
if category not in CATEGORY_LIST:
|
if category not in CATEGORY_LIST:
|
||||||
raise ValueError(f"Recommended {category} is not valid.")
|
raise ValueError(f"Recommended {category} is not valid.")
|
||||||
|
|
||||||
@@ -86,10 +86,14 @@ class VectorDatabase():
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
metadatas = results['metadatas'][0] # List[Dict[str, Any]]
|
metadatas = results['metadatas'][0] # List[Dict[str, Any]]
|
||||||
final_scores = []
|
candidate_pool = []
|
||||||
|
all_scores = []
|
||||||
for idx, metadata in enumerate(metadatas):
|
for idx, metadata in enumerate(metadatas):
|
||||||
dist_img = results['distances'][0][idx]
|
dist_img = results['distances'][0][idx]
|
||||||
score_vec = 1 - dist_img # cosine similarity range: [-1, 1]
|
score_vec = 1 - dist_img # cosine similarity range: [-1, 1]
|
||||||
|
score_subcategory = 0.0
|
||||||
|
if subcategory == metadata['subcategory']:
|
||||||
|
score_subcategory = 1
|
||||||
|
|
||||||
score_occ = 0.0
|
score_occ = 0.0
|
||||||
if occasions:
|
if occasions:
|
||||||
@@ -100,26 +104,42 @@ class VectorDatabase():
|
|||||||
count += 1
|
count += 1
|
||||||
status_val = metadata.get(occ, -1)
|
status_val = metadata.get(occ, -1)
|
||||||
if status_val == 1:
|
if status_val == 1:
|
||||||
score_occ += 1.0
|
score_occ += 5.0
|
||||||
elif status_val == 0:
|
elif status_val == 0:
|
||||||
score_occ += 0.0
|
score_occ += 0.0
|
||||||
else:
|
else:
|
||||||
score_occ -= 100.0
|
score_occ -= 5.0
|
||||||
|
|
||||||
score_occ = score_occ / count if count else 0.0
|
score_occ = score_occ / count if count else 0.0
|
||||||
|
|
||||||
final_score = 0.6 * score_vec + 0.4 * score_occ
|
final_score = 0.5 * score_vec + 0.1 * score_occ + 0.5 * score_subcategory
|
||||||
final_scores.append(final_score)
|
all_scores.append(final_score)
|
||||||
|
if final_score > 0:
|
||||||
|
candidate_pool.append({
|
||||||
|
"score": final_score,
|
||||||
|
"metadata": metadata
|
||||||
|
})
|
||||||
|
|
||||||
scores_arr = np.array(final_scores)
|
if not candidate_pool:
|
||||||
temperature = 0.5
|
print(f"All scores are negative: Ignore")
|
||||||
scores_arr = scores_arr / temperature
|
return []
|
||||||
|
|
||||||
|
# 采取topk截断
|
||||||
|
candidate_pool.sort(key=lambda x: x['score'], reverse=True)
|
||||||
|
current_k = min(10, len(candidate_pool))
|
||||||
|
top_candidates = candidate_pool[:current_k]
|
||||||
|
top_scores = np.array([x['score'] for x in top_candidates])
|
||||||
|
|
||||||
|
#降低温度,使得选择稳定
|
||||||
|
temperature = 0.2
|
||||||
|
scaled_scores = top_scores / temperature
|
||||||
|
|
||||||
# Softmax: 将分数转换为概率
|
# Softmax: 将分数转换为概率
|
||||||
exp_scores = np.exp(scores_arr - np.max(scores_arr))
|
exp_scores = np.exp(scaled_scores - np.max(scaled_scores))
|
||||||
probabilities = exp_scores / np.sum(exp_scores)
|
probabilities = exp_scores / np.sum(exp_scores)
|
||||||
|
|
||||||
# 采样 (或直接取 Top 1)
|
# 采样 (或直接取 Top 1)
|
||||||
sampled_index = np.random.choice(a=len(results['ids'][0]), p=probabilities, size=n_results, replace=False) # 不重复采样
|
sampled_index = np.random.choice(a=current_k, p=probabilities, size=min(n_results, current_k), replace=False) # 不重复采样
|
||||||
sampled_items = [metadatas[i] for i in sampled_index]
|
sampled_items = [top_candidates[i]['metadata'] for i in sampled_index]
|
||||||
|
|
||||||
return sampled_items
|
return sampled_items
|
||||||
|
|||||||
Reference in New Issue
Block a user