UPDATE: add color constrain in vector database for wedding occasion avoiding black items.

This commit is contained in:
pangkaicheng
2026-01-07 17:26:44 +08:00
parent 773db4fcc3
commit 46793ba271
3 changed files with 9 additions and 14 deletions

View File

@@ -357,10 +357,10 @@ if __name__ == "__main__":
request_data = json.load(f) request_data = json.load(f)
tasks_with_metadata = [] tasks_with_metadata = []
for test_content in request_data[8:10]: for test_content in request_data[9:10]:
occasions = test_content['occasions'] occasions = test_content['occasions']
request_summary = test_content['request_summary'] request_summary = test_content['request_summary']
for stylist_name in ["edi"]: for stylist_name in ["crystal", "mini", "vera", "edi"]:
stylist_agent_kwages['outfit_id'] = test_content['test_case_id'] + '_' + test_content['occasions'][0].replace('/', '_') + f"_{stylist_name}" stylist_agent_kwages['outfit_id'] = test_content['test_case_id'] + '_' + test_content['occasions'][0].replace('/', '_') + f"_{stylist_name}"
stylist_agent_kwages['stylist_name'] = stylist_name stylist_agent_kwages['stylist_name'] = stylist_name
stylist_agent_kwages['gender'] = "female" stylist_agent_kwages['gender'] = "female"

View File

@@ -313,7 +313,7 @@ class AsyncStylistAgent:
if status in ['failed']: if status in ['failed']:
# 失败直接打印参数 不发送结果 # 失败直接打印参数 不发送结果
response_data['message'] = message response_data['message'] = message
logger.info(f"request data {json.dumps(response_data, ensure_ascii=False, indent=2)}") logger.error(f"request data {json.dumps(response_data, ensure_ascii=False, indent=2)}")
else: else:
response = post_request(url=callback_url, data=json.dumps(response_data)) response = post_request(url=callback_url, data=json.dumps(response_data))
logger.info(f"request data {json.dumps(response_data, ensure_ascii=False, indent=2)} | JAVA callback info -> status:{response.status_code} | message:{response.text}") logger.info(f"request data {json.dumps(response_data, ensure_ascii=False, indent=2)} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
@@ -646,11 +646,10 @@ class AsyncStylistAgent:
else: else:
self.post_operation( self.post_operation(
status="failed", status="failed",
message=f"Failed to assemble a complete outfit after {max_retries} attempts for {occasions[0]}.", message=f"Failed to assemble a complete outfit after {max_retries} attempts for {occasions[0]}. Current items: {self.outfit_items}. Subcategories required by this occasion is {allowed_subcategories}",
callback_url=url, callback_url=url,
img_path="" img_path=""
) )
logger.error(f"Failed to assemble a complete outfit after {max_retries} attempts for {occasions[0]}. Current items: {self.outfit_items}. Subcategories required by this occasion is {allowed_subcategories}")
raise Exception(f"Failed to assemble a complete outfit after {max_retries} attempts for {occasions[0]}. Current items: {self.outfit_items}. Subcategories required by this occasion is {allowed_subcategories}") raise Exception(f"Failed to assemble a complete outfit after {max_retries} attempts for {occasions[0]}. Current items: {self.outfit_items}. Subcategories required by this occasion is {allowed_subcategories}")
# 推荐即将完成 回调通知前端 # 推荐即将完成 回调通知前端

View File

@@ -67,16 +67,12 @@ class VectorDatabase():
if brand_strication: if brand_strication:
and_conditions.append({"brand": {"$in": BRAND_WHITELIST}}) and_conditions.append({"brand": {"$in": BRAND_WHITELIST}})
if batch_sources and len(batch_sources) > 0: # 加一条occasion限制婚礼不能穿黑色
if len(batch_sources) == 1: if any(o in ["Bridal / Wedding", "Beach / Swim"] for o in occasions):
and_conditions.append({"batch_source": batch_sources[0]}) and_conditions.append({"color": {"$nin": ["BLACK", "DARK GREY", "DARK BLUE", "NAVY"]}})
else:
source_conditions = []
for source in batch_sources:
source_conditions.append({"batch_source": source})
# 将 Batch Source 的 OR 子句添加到主 AND 条件中 if batch_sources and len(batch_sources) > 0:
and_conditions.append({"$or": source_conditions}) and_conditions.append({"batch_source": {"$in": batch_sources}})
results = self.collection.query( results = self.collection.query(
query_embeddings=[embedding], query_embeddings=[embedding],