Enable data auto process for new data
This commit is contained in:
@@ -51,7 +51,7 @@ class AgentRequestModel(BaseModel):
|
||||
session_id: str
|
||||
num_outfits: int
|
||||
stylist_path: str
|
||||
batch_source: str
|
||||
batch_sources: List[str]
|
||||
callback_url: str
|
||||
gender: str
|
||||
max_len: int = 9
|
||||
@@ -112,8 +112,8 @@ class LCAgent(ls.LitAPI):
|
||||
request_summary=request_summary,
|
||||
occasions=occasions,
|
||||
stylist_name=request.stylist_path,
|
||||
batch_source=request.batch_source,
|
||||
start_outfit=[],
|
||||
batch_sources=request.batch_sources,
|
||||
num_outfits=request.num_outfits,
|
||||
user_id=request.user_id,
|
||||
gender=request.gender,
|
||||
@@ -162,7 +162,7 @@ class LCAgent(ls.LitAPI):
|
||||
return str(parsed_result.summary), [occ.value for occ in parsed_result.occasions]
|
||||
|
||||
async def recommend_outfit(
|
||||
self, request_summary: str, occasions: List[str], batch_source: str, stylist_name: str, start_outfit=[],
|
||||
self, request_summary: str, occasions: List[str], stylist_name: str, start_outfit: List = [], batch_sources: List[str] = [],
|
||||
num_outfits: int = 1, user_id: str = "test", gender: str = "male",
|
||||
callback_url: str = None, max_len: int = 9, outfit_ids=None
|
||||
):
|
||||
@@ -186,9 +186,9 @@ class LCAgent(ls.LitAPI):
|
||||
task = agent.run_styling_process(
|
||||
request_summary=request_summary,
|
||||
occasions=occasions,
|
||||
batch_source=batch_source,
|
||||
stylist_name=stylist_name,
|
||||
start_outfit=start_outfit,
|
||||
batch_sources=batch_sources,
|
||||
user_id=user_id,
|
||||
callback_url=callback_url,
|
||||
gender=gender,
|
||||
@@ -227,9 +227,9 @@ class LCAgent(ls.LitAPI):
|
||||
new_task = agent.run_styling_process(
|
||||
request_summary=request_summary,
|
||||
occasions=occasions,
|
||||
batch_source=batch_source,
|
||||
stylist_name=stylist_name,
|
||||
start_outfit=start_outfit,
|
||||
batch_sources=batch_sources,
|
||||
user_id=user_id,
|
||||
callback_url=callback_url
|
||||
)
|
||||
@@ -295,9 +295,9 @@ if __name__ == "__main__":
|
||||
task = agent.run_styling_process(
|
||||
request_summary=request_summary,
|
||||
occasions=occasions,
|
||||
batch_source="2025_q4",
|
||||
stylist_name=stylist_name,
|
||||
start_outfit=[],
|
||||
batch_sources=["2025_q4"],
|
||||
user_id=test_content['test_case_id'],
|
||||
callback_url="http://mock-callback.com/result",
|
||||
gender="female",
|
||||
|
||||
@@ -16,11 +16,16 @@ from app.server.utils.img_operation import merge_images_to_square
|
||||
from app.server.utils.minio_client import minio_client, oss_upload_image
|
||||
from app.server.utils.request_post import post_request
|
||||
from app.config import settings
|
||||
from app.taxonomy import CLOTHING_CATEGORY, ACCESSORY_CATEGORY
|
||||
from app.taxonomy import CATEGORY, ALL_CATEGORY, IGNORE_CATEGORY
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
IGNORE_CATEGORY = set(IGNORE_CATEGORY)
|
||||
CLOTHING_CATEGORY = set(CATEGORY['clothing'] + CATEGORY['shoes'] + CATEGORY['bags']) - IGNORE_CATEGORY
|
||||
ACCESSORY_CATEGORY = set(CATEGORY['accessories']) - IGNORE_CATEGORY
|
||||
|
||||
|
||||
class AsyncStylistAgent:
|
||||
def __init__(self, local_db, max_len: int, gemini_model_name: str, outfit_id=str):
|
||||
# self.outfit_items: List[Dict[str, str]] = []
|
||||
@@ -145,7 +150,7 @@ class AsyncStylistAgent:
|
||||
```
|
||||
|
||||
* `action`: Must always be `"recommend_item"` until the outfit is complete.
|
||||
* `category`: Must be an unused category from the following list: {CLOTHING_CATEGORY} (strictly no repeats, per the Category Uniqueness Mandate).
|
||||
* `category`: Must be an unused category from the following list: {list(CLOTHING_CATEGORY)} (strictly no repeats, per the Category Uniqueness Mandate).
|
||||
* `description`: This must be an **extremely detailed and precise** description of the item. This description is used for **high-accuracy vector search** in the database and must include:
|
||||
* **Color** (e.g., milk tea, pure white, dark gray)
|
||||
* **Fit/Silhouette** (e.g., Oversize, loose, slim-fit)
|
||||
@@ -193,7 +198,7 @@ class AsyncStylistAgent:
|
||||
---
|
||||
## STRICT RULES
|
||||
1. **Batch Recommendation**: Do NOT recommend items one by one. You must output the **COMPLETE LIST** of accessories (e.g., jewelry, bag, watch, hat) in a single response using the 'recommended_accessories' list.
|
||||
2. **Allowed Categories**: Select only from: {ACCESSORY_CATEGORY}.
|
||||
2. **Allowed Categories**: Select only from: {list(ACCESSORY_CATEGORY)}.
|
||||
3. **Harmony & Constraints**:
|
||||
- The accessories must complement the [Current Outfit Base].
|
||||
- Strictly follow the [Accessories Style Guide] regarding metals (gold/silver), numbers, and prohibited items.
|
||||
@@ -295,40 +300,46 @@ class AsyncStylistAgent:
|
||||
print(f"Raw response: {response_text}")
|
||||
return None
|
||||
|
||||
def _get_next_item(self, item_description: str, category: str, occasions: List[str], batch_source: str = "2025_q4", gender: str = "female") -> Optional[Dict[str, str]]:
|
||||
def _get_next_item(self, item_description: str, category: str, occasions: List[str], batch_sources: List[str] = [], gender: str = "female") -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
1. 根据描述生成嵌入。
|
||||
2. 查询本地数据库以找到最佳匹配项。
|
||||
3. 模拟 Agent 审核匹配项(这里简化为总是通过)。
|
||||
"""
|
||||
# 1. 生成查询嵌入
|
||||
query_embedding = self.local_db.get_clip_embedding(item_description, is_image=False)
|
||||
|
||||
# 2. 执行查询,并过滤类别
|
||||
try:
|
||||
# 1. 生成查询嵌入
|
||||
query_embedding = self.local_db.get_clip_embedding(item_description, is_image=False)
|
||||
results = self.local_db.get_matched_item(
|
||||
query_embedding,
|
||||
category,
|
||||
occasions=occasions,
|
||||
batch_sources=batch_sources,
|
||||
gender=gender,
|
||||
n_results=1
|
||||
)
|
||||
except ValueError as e:
|
||||
print(f"检测到无效参数错误:{e}")
|
||||
results = []
|
||||
|
||||
# 2. 执行查询,并过滤类别
|
||||
results = self.local_db.get_matched_item(query_embedding, category, occasions=occasions, batch_source=batch_source, gender=gender, n_results=1)
|
||||
|
||||
if not results:
|
||||
print(f"❌ 数据库中未找到符合 '{category}' 和描述的单品。")
|
||||
return None
|
||||
|
||||
# 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核)
|
||||
best_meta = results[0] # 第一个 batch 的第一个 metadata
|
||||
item_id = best_meta['item_id'].replace("_img", "")
|
||||
return {
|
||||
"item_id": item_id, # 从 metadata 字典中安全获取
|
||||
"category": best_meta['category'],
|
||||
"gpt_description": item_description,
|
||||
'description': best_meta['description'],
|
||||
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
|
||||
# 这里假设 item_id 就是文件名的一部分
|
||||
"image_path": os.path.join(f"{item_id}.jpg")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"An error occurred during item retrieval: {e}")
|
||||
if not results:
|
||||
print(f"数据库中未找到符合 '{category}' 和描述的单品。")
|
||||
return None
|
||||
|
||||
# 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核)
|
||||
best_meta = results[0] # 第一个 batch 的第一个 metadata
|
||||
item_id = best_meta['item_id'].replace("_img", "")
|
||||
return {
|
||||
"item_id": item_id, # 从 metadata 字典中安全获取
|
||||
"category": best_meta['category'],
|
||||
"gpt_description": item_description,
|
||||
'description': best_meta['description'],
|
||||
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
|
||||
# 这里假设 item_id 就是文件名的一部分
|
||||
"image_path": os.path.join(f"{item_id}.jpg")
|
||||
}
|
||||
|
||||
def _build_user_input(self, recommend_acc=False) -> str:
|
||||
"""构建发送给 Gemini 的用户输入,包含已选单品信息。"""
|
||||
if not self.outfit_items:
|
||||
@@ -353,7 +364,7 @@ class AsyncStylistAgent:
|
||||
response = post_request(url=callback_url, data=json.dumps(response_data), headers=self.headers)
|
||||
logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
|
||||
|
||||
async def run_styling_process(self, request_summary, occasions, stylist_name, batch_source="2025_q4", start_outfit=[], user_id="test", callback_url="", gender: str = "male"):
|
||||
async def run_styling_process(self, request_summary, occasions, stylist_name, start_outfit=[], batch_sources=[], user_id="test", callback_url="", gender: str = "male"):
|
||||
self.outfit_items = start_outfit
|
||||
"""主流程控制循环。"""
|
||||
print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---")
|
||||
@@ -412,7 +423,7 @@ class AsyncStylistAgent:
|
||||
continue
|
||||
|
||||
# 4b. 在本地 DB 中查询单品
|
||||
new_item = self._get_next_item(description, category, occasions, batch_source, gender)
|
||||
new_item = self._get_next_item(description, category, occasions, batch_sources, gender)
|
||||
if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
|
||||
self.post_operation(
|
||||
response_data,
|
||||
@@ -460,7 +471,7 @@ class AsyncStylistAgent:
|
||||
continue
|
||||
|
||||
# 4b. 在本地 DB 中查询单品
|
||||
new_item = self._get_next_item(description, category, occasions, batch_source, gender)
|
||||
new_item = self._get_next_item(description, category, occasions, batch_sources, gender)
|
||||
if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
|
||||
continue
|
||||
else:
|
||||
|
||||
@@ -8,13 +8,12 @@ from PIL import Image
|
||||
from typing import List, Dict, Any
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
|
||||
from app.taxonomy import CATEGORY, OCCASION
|
||||
from app.taxonomy import OCCASION, ALL_CATEGORY
|
||||
|
||||
|
||||
class VectorDatabase():
|
||||
def __init__(self, vector_db_dir: str, collection_name: str, embedding_model_name: str):
|
||||
self.client = chromadb.PersistentClient(path=vector_db_dir)
|
||||
|
||||
self.collection = self.client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
configuration={
|
||||
@@ -23,17 +22,9 @@ class VectorDatabase():
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
|
||||
self.processor = CLIPProcessor.from_pretrained(embedding_model_name)
|
||||
# self.cache_filtered_ids = self.load_filtered_ids([
|
||||
# {"item_group_id": {"$ne": "Clothing"}},
|
||||
# {"item_group_id": {"$ne": "Shoes"}},
|
||||
# {"modality": "image"}
|
||||
# ])
|
||||
# self.total_count = len(self.cache_filtered_ids)
|
||||
|
||||
def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
|
||||
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
|
||||
@@ -57,46 +48,32 @@ class VectorDatabase():
|
||||
features = features / features.norm(p=2, dim=-1, keepdim=True)
|
||||
|
||||
return features.cpu().numpy().flatten().tolist()
|
||||
|
||||
def query_local_db(self, embedding: List[float], category: str, occasions: List[str] = [], n_results: int = 3) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
基于嵌入向量在本地数据库中查询相似单品。
|
||||
实际应执行 ChromaDB 查询,并根据 category 进行过滤(metadatas)。
|
||||
"""
|
||||
for occasion in occasions:
|
||||
where_clauses = {
|
||||
"$and": [
|
||||
{"category": category},
|
||||
{"modality": "image"},
|
||||
{"batch_source": '2025_q4'}
|
||||
]
|
||||
}
|
||||
if occasion not in OCCASION:
|
||||
continue
|
||||
else:
|
||||
where_clauses['$and'].append({occasion: 1})
|
||||
|
||||
results = self.collection.query(
|
||||
query_embeddings=[embedding],
|
||||
n_results=n_results,
|
||||
where=where_clauses,
|
||||
include=['metadatas', 'distances']
|
||||
)
|
||||
return results
|
||||
|
||||
def get_matched_item(self, embedding: List[float], category: str, occasions: List[str] = [], batch_source: str = "2025_q4", gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]:
|
||||
def get_matched_item(self, embedding: List[float], category: str, occasions: List[str] = [], batch_sources: List[str] = [], gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]:
|
||||
if category not in ALL_CATEGORY:
|
||||
raise ValueError(f"Recommended {category} is not valid.")
|
||||
|
||||
and_conditions = [
|
||||
{"category": category},
|
||||
{"modality": "image"},
|
||||
{"$or": [
|
||||
{"gender": gender},
|
||||
{"gender": "unisex"},
|
||||
]}
|
||||
]
|
||||
if batch_sources and len(batch_sources) > 0:
|
||||
source_conditions = []
|
||||
for source in batch_sources:
|
||||
source_conditions.append({"batch_source": source})
|
||||
|
||||
# 将 Batch Source 的 OR 子句添加到主 AND 条件中
|
||||
and_conditions.append({"$or": source_conditions})
|
||||
|
||||
results = self.collection.query(
|
||||
query_embeddings=[embedding],
|
||||
n_results=500,
|
||||
where={
|
||||
"$and": [
|
||||
{"category": category},
|
||||
{"modality": "image"},
|
||||
{"gender": gender},
|
||||
{"batch_source": batch_source}
|
||||
]
|
||||
},
|
||||
include=['metadatas', 'distances']
|
||||
where={"$and": and_conditions},
|
||||
include=['metadatas', 'distances'],
|
||||
)
|
||||
if not results['ids'][0]:
|
||||
return []
|
||||
@@ -124,7 +101,7 @@ class VectorDatabase():
|
||||
|
||||
score_occ = score_occ / count if count else 0.0
|
||||
|
||||
final_score = 0.6 * score_vec + 0.3 * score_occ
|
||||
final_score = 0.6 * score_vec + 0.4 * score_occ
|
||||
final_scores.append(final_score)
|
||||
|
||||
scores_arr = np.array(final_scores)
|
||||
@@ -139,79 +116,3 @@ class VectorDatabase():
|
||||
sampled_index = np.random.choice(a=len(results['ids'][0]), p=probabilities, size=n_results, replace=False) # 不重复采样
|
||||
sampled_items = [metadatas[i] for i in sampled_index]
|
||||
return sampled_items
|
||||
|
||||
def load_filtered_ids(self, filter_item):
|
||||
# print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
|
||||
start_time = time.time()
|
||||
FILTER_CRITERIA = {
|
||||
"$and": filter_item
|
||||
}
|
||||
MAX_LIMIT = 100000
|
||||
|
||||
try:
|
||||
# 获取所有符合条件的 ID
|
||||
all_ids_results = self.collection.get(
|
||||
where=FILTER_CRITERIA,
|
||||
limit=MAX_LIMIT,
|
||||
include=[]
|
||||
)
|
||||
all_matched_ids = all_ids_results['ids']
|
||||
# print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
|
||||
print(time.time() - start_time)
|
||||
return all_matched_ids
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
|
||||
return []
|
||||
|
||||
def random_get_accessories(self, ids):
|
||||
# 2. 调用 ChromaDB:只查询这一个 ID 的详细信息
|
||||
try:
|
||||
final_results = self.collection.get(
|
||||
ids=ids,
|
||||
include=["metadatas"] # 你只需要元数据
|
||||
)
|
||||
|
||||
# 提取结果
|
||||
if final_results['ids']:
|
||||
return final_results
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 获取最终记录时发生错误: {e}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
stylist = {
|
||||
'text': "gold necklace",
|
||||
'count': 2,
|
||||
'category': "Jewelry"
|
||||
}
|
||||
|
||||
max_len = 5
|
||||
local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
|
||||
A = local_db.load_filtered_ids([
|
||||
{"item_group_id": {"$ne": "Clothing"}},
|
||||
{"item_group_id": {"$ne": "Shoes"}},
|
||||
{"modality": "image"}
|
||||
])
|
||||
# print(db.random_get_accessories())
|
||||
start_time = time.time()
|
||||
X = local_db.random_get_accessories(['ELI699_img'])
|
||||
print(X)
|
||||
print(time.time() - start_time)
|
||||
# query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False)
|
||||
#
|
||||
# results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10)
|
||||
# # 2. 从结果集中抽 stylist['count'] 个item
|
||||
# stylist_item = random.choices(results['metadatas'][0], k=stylist['count'])
|
||||
# stylist_item_ids = [item_id['item_id'] for item_id in stylist_item]
|
||||
#
|
||||
# # 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item
|
||||
# accessories_count = 9 - max_len - stylist['count']
|
||||
#
|
||||
# random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
|
||||
# random_items = local_db.random_get_accessories(random_single_ids)['metadatas']
|
||||
# all_items = stylist_item + random_items
|
||||
|
||||
@@ -1,15 +1,4 @@
|
||||
# 这个文件用来储存所有的category和occasion,这是标准文件。
|
||||
|
||||
CATEGORY = [
|
||||
'shoes', 'bags', 'dresses', 'tops', 'pants', 'skirts', 'outerwear', 'swimwear', 'suits',
|
||||
'watches', 'sunglasses', 'belts', 'hats', 'jewelry', 'neckties', 'scarves & shawls'
|
||||
]
|
||||
CLOTHING_CATEGORY = [
|
||||
'shoes', 'bags', 'dresses', 'tops', 'pants', 'skirts', 'outerwear', 'swimwear'
|
||||
]
|
||||
ACCESSORY_CATEGORY = [
|
||||
'watches', 'sunglasses', 'belts', 'hats', 'jewelry', 'neckties', 'scarves & shawls'
|
||||
]
|
||||
OCCASION = [
|
||||
"Casual", "Formal", "Activewear", "Resort", "Evening", "Outdoor",
|
||||
"Business / workwear", "Cocktail / Semi-Formal", "Black Tie / White Tie",
|
||||
@@ -17,3 +6,54 @@ OCCASION = [
|
||||
"Travel / Transit", "Athleisure", "Beach / Swim", "Ski / Snow / Mountain",
|
||||
"Garden Party / Daytime Event"
|
||||
]
|
||||
|
||||
CATEGORY = {
|
||||
'clothing': [
|
||||
'coats',
|
||||
'jackets',
|
||||
'blazers',
|
||||
'puffer',
|
||||
'cardigan',
|
||||
'sweater',
|
||||
'shirts',
|
||||
't-shirts',
|
||||
'pullover',
|
||||
'polos',
|
||||
'bodysuits',
|
||||
'dresses',
|
||||
'skirts',
|
||||
'jeans',
|
||||
'shorts',
|
||||
'leggings',
|
||||
'jumpsuits',
|
||||
'swimwear',
|
||||
],
|
||||
'shoes': [
|
||||
'sneakers',
|
||||
'formal shoes',
|
||||
'heels',
|
||||
'flats',
|
||||
'sandals',
|
||||
'slides',
|
||||
'boots',
|
||||
],
|
||||
'bags': [
|
||||
'bags'
|
||||
],
|
||||
'accessories': [
|
||||
'necklaces',
|
||||
'bracelets',
|
||||
'jewellery',
|
||||
'eyewear',
|
||||
'scarves',
|
||||
'hats',
|
||||
'gloves',
|
||||
'belts',
|
||||
'socks',
|
||||
'watches'
|
||||
'ties',
|
||||
]
|
||||
}
|
||||
ALL_CATEGORY = sum(CATEGORY.values(), [])
|
||||
|
||||
IGNORE_CATEGORY = ['socks']
|
||||
Reference in New Issue
Block a user