Enable data auto process for new data
This commit is contained in:
@@ -51,7 +51,7 @@ class AgentRequestModel(BaseModel):
|
|||||||
session_id: str
|
session_id: str
|
||||||
num_outfits: int
|
num_outfits: int
|
||||||
stylist_path: str
|
stylist_path: str
|
||||||
batch_source: str
|
batch_sources: List[str]
|
||||||
callback_url: str
|
callback_url: str
|
||||||
gender: str
|
gender: str
|
||||||
max_len: int = 9
|
max_len: int = 9
|
||||||
@@ -112,8 +112,8 @@ class LCAgent(ls.LitAPI):
|
|||||||
request_summary=request_summary,
|
request_summary=request_summary,
|
||||||
occasions=occasions,
|
occasions=occasions,
|
||||||
stylist_name=request.stylist_path,
|
stylist_name=request.stylist_path,
|
||||||
batch_source=request.batch_source,
|
|
||||||
start_outfit=[],
|
start_outfit=[],
|
||||||
|
batch_sources=request.batch_sources,
|
||||||
num_outfits=request.num_outfits,
|
num_outfits=request.num_outfits,
|
||||||
user_id=request.user_id,
|
user_id=request.user_id,
|
||||||
gender=request.gender,
|
gender=request.gender,
|
||||||
@@ -162,7 +162,7 @@ class LCAgent(ls.LitAPI):
|
|||||||
return str(parsed_result.summary), [occ.value for occ in parsed_result.occasions]
|
return str(parsed_result.summary), [occ.value for occ in parsed_result.occasions]
|
||||||
|
|
||||||
async def recommend_outfit(
|
async def recommend_outfit(
|
||||||
self, request_summary: str, occasions: List[str], batch_source: str, stylist_name: str, start_outfit=[],
|
self, request_summary: str, occasions: List[str], stylist_name: str, start_outfit: List = [], batch_sources: List[str] = [],
|
||||||
num_outfits: int = 1, user_id: str = "test", gender: str = "male",
|
num_outfits: int = 1, user_id: str = "test", gender: str = "male",
|
||||||
callback_url: str = None, max_len: int = 9, outfit_ids=None
|
callback_url: str = None, max_len: int = 9, outfit_ids=None
|
||||||
):
|
):
|
||||||
@@ -186,9 +186,9 @@ class LCAgent(ls.LitAPI):
|
|||||||
task = agent.run_styling_process(
|
task = agent.run_styling_process(
|
||||||
request_summary=request_summary,
|
request_summary=request_summary,
|
||||||
occasions=occasions,
|
occasions=occasions,
|
||||||
batch_source=batch_source,
|
|
||||||
stylist_name=stylist_name,
|
stylist_name=stylist_name,
|
||||||
start_outfit=start_outfit,
|
start_outfit=start_outfit,
|
||||||
|
batch_sources=batch_sources,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
callback_url=callback_url,
|
callback_url=callback_url,
|
||||||
gender=gender,
|
gender=gender,
|
||||||
@@ -227,9 +227,9 @@ class LCAgent(ls.LitAPI):
|
|||||||
new_task = agent.run_styling_process(
|
new_task = agent.run_styling_process(
|
||||||
request_summary=request_summary,
|
request_summary=request_summary,
|
||||||
occasions=occasions,
|
occasions=occasions,
|
||||||
batch_source=batch_source,
|
|
||||||
stylist_name=stylist_name,
|
stylist_name=stylist_name,
|
||||||
start_outfit=start_outfit,
|
start_outfit=start_outfit,
|
||||||
|
batch_sources=batch_sources,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
callback_url=callback_url
|
callback_url=callback_url
|
||||||
)
|
)
|
||||||
@@ -295,9 +295,9 @@ if __name__ == "__main__":
|
|||||||
task = agent.run_styling_process(
|
task = agent.run_styling_process(
|
||||||
request_summary=request_summary,
|
request_summary=request_summary,
|
||||||
occasions=occasions,
|
occasions=occasions,
|
||||||
batch_source="2025_q4",
|
|
||||||
stylist_name=stylist_name,
|
stylist_name=stylist_name,
|
||||||
start_outfit=[],
|
start_outfit=[],
|
||||||
|
batch_sources=["2025_q4"],
|
||||||
user_id=test_content['test_case_id'],
|
user_id=test_content['test_case_id'],
|
||||||
callback_url="http://mock-callback.com/result",
|
callback_url="http://mock-callback.com/result",
|
||||||
gender="female",
|
gender="female",
|
||||||
|
|||||||
@@ -16,11 +16,16 @@ from app.server.utils.img_operation import merge_images_to_square
|
|||||||
from app.server.utils.minio_client import minio_client, oss_upload_image
|
from app.server.utils.minio_client import minio_client, oss_upload_image
|
||||||
from app.server.utils.request_post import post_request
|
from app.server.utils.request_post import post_request
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.taxonomy import CLOTHING_CATEGORY, ACCESSORY_CATEGORY
|
from app.taxonomy import CATEGORY, ALL_CATEGORY, IGNORE_CATEGORY
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
IGNORE_CATEGORY = set(IGNORE_CATEGORY)
|
||||||
|
CLOTHING_CATEGORY = set(CATEGORY['clothing'] + CATEGORY['shoes'] + CATEGORY['bags']) - IGNORE_CATEGORY
|
||||||
|
ACCESSORY_CATEGORY = set(CATEGORY['accessories']) - IGNORE_CATEGORY
|
||||||
|
|
||||||
|
|
||||||
class AsyncStylistAgent:
|
class AsyncStylistAgent:
|
||||||
def __init__(self, local_db, max_len: int, gemini_model_name: str, outfit_id=str):
|
def __init__(self, local_db, max_len: int, gemini_model_name: str, outfit_id=str):
|
||||||
# self.outfit_items: List[Dict[str, str]] = []
|
# self.outfit_items: List[Dict[str, str]] = []
|
||||||
@@ -145,7 +150,7 @@ class AsyncStylistAgent:
|
|||||||
```
|
```
|
||||||
|
|
||||||
* `action`: Must always be `"recommend_item"` until the outfit is complete.
|
* `action`: Must always be `"recommend_item"` until the outfit is complete.
|
||||||
* `category`: Must be an unused category from the following list: {CLOTHING_CATEGORY} (strictly no repeats, per the Category Uniqueness Mandate).
|
* `category`: Must be an unused category from the following list: {list(CLOTHING_CATEGORY)} (strictly no repeats, per the Category Uniqueness Mandate).
|
||||||
* `description`: This must be an **extremely detailed and precise** description of the item. This description is used for **high-accuracy vector search** in the database and must include:
|
* `description`: This must be an **extremely detailed and precise** description of the item. This description is used for **high-accuracy vector search** in the database and must include:
|
||||||
* **Color** (e.g., milk tea, pure white, dark gray)
|
* **Color** (e.g., milk tea, pure white, dark gray)
|
||||||
* **Fit/Silhouette** (e.g., Oversize, loose, slim-fit)
|
* **Fit/Silhouette** (e.g., Oversize, loose, slim-fit)
|
||||||
@@ -193,7 +198,7 @@ class AsyncStylistAgent:
|
|||||||
---
|
---
|
||||||
## STRICT RULES
|
## STRICT RULES
|
||||||
1. **Batch Recommendation**: Do NOT recommend items one by one. You must output the **COMPLETE LIST** of accessories (e.g., jewelry, bag, watch, hat) in a single response using the 'recommended_accessories' list.
|
1. **Batch Recommendation**: Do NOT recommend items one by one. You must output the **COMPLETE LIST** of accessories (e.g., jewelry, bag, watch, hat) in a single response using the 'recommended_accessories' list.
|
||||||
2. **Allowed Categories**: Select only from: {ACCESSORY_CATEGORY}.
|
2. **Allowed Categories**: Select only from: {list(ACCESSORY_CATEGORY)}.
|
||||||
3. **Harmony & Constraints**:
|
3. **Harmony & Constraints**:
|
||||||
- The accessories must complement the [Current Outfit Base].
|
- The accessories must complement the [Current Outfit Base].
|
||||||
- Strictly follow the [Accessories Style Guide] regarding metals (gold/silver), numbers, and prohibited items.
|
- Strictly follow the [Accessories Style Guide] regarding metals (gold/silver), numbers, and prohibited items.
|
||||||
@@ -295,40 +300,46 @@ class AsyncStylistAgent:
|
|||||||
print(f"Raw response: {response_text}")
|
print(f"Raw response: {response_text}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_next_item(self, item_description: str, category: str, occasions: List[str], batch_source: str = "2025_q4", gender: str = "female") -> Optional[Dict[str, str]]:
|
def _get_next_item(self, item_description: str, category: str, occasions: List[str], batch_sources: List[str] = [], gender: str = "female") -> Optional[Dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
1. 根据描述生成嵌入。
|
1. 根据描述生成嵌入。
|
||||||
2. 查询本地数据库以找到最佳匹配项。
|
2. 查询本地数据库以找到最佳匹配项。
|
||||||
3. 模拟 Agent 审核匹配项(这里简化为总是通过)。
|
3. 模拟 Agent 审核匹配项(这里简化为总是通过)。
|
||||||
"""
|
"""
|
||||||
|
# 1. 生成查询嵌入
|
||||||
|
query_embedding = self.local_db.get_clip_embedding(item_description, is_image=False)
|
||||||
|
|
||||||
|
# 2. 执行查询,并过滤类别
|
||||||
try:
|
try:
|
||||||
# 1. 生成查询嵌入
|
results = self.local_db.get_matched_item(
|
||||||
query_embedding = self.local_db.get_clip_embedding(item_description, is_image=False)
|
query_embedding,
|
||||||
|
category,
|
||||||
|
occasions=occasions,
|
||||||
|
batch_sources=batch_sources,
|
||||||
|
gender=gender,
|
||||||
|
n_results=1
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
print(f"检测到无效参数错误:{e}")
|
||||||
|
results = []
|
||||||
|
|
||||||
# 2. 执行查询,并过滤类别
|
if not results:
|
||||||
results = self.local_db.get_matched_item(query_embedding, category, occasions=occasions, batch_source=batch_source, gender=gender, n_results=1)
|
print(f"数据库中未找到符合 '{category}' 和描述的单品。")
|
||||||
|
|
||||||
if not results:
|
|
||||||
print(f"❌ 数据库中未找到符合 '{category}' 和描述的单品。")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核)
|
|
||||||
best_meta = results[0] # 第一个 batch 的第一个 metadata
|
|
||||||
item_id = best_meta['item_id'].replace("_img", "")
|
|
||||||
return {
|
|
||||||
"item_id": item_id, # 从 metadata 字典中安全获取
|
|
||||||
"category": best_meta['category'],
|
|
||||||
"gpt_description": item_description,
|
|
||||||
'description': best_meta['description'],
|
|
||||||
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
|
|
||||||
# 这里假设 item_id 就是文件名的一部分
|
|
||||||
"image_path": os.path.join(f"{item_id}.jpg")
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred during item retrieval: {e}")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核)
|
||||||
|
best_meta = results[0] # 第一个 batch 的第一个 metadata
|
||||||
|
item_id = best_meta['item_id'].replace("_img", "")
|
||||||
|
return {
|
||||||
|
"item_id": item_id, # 从 metadata 字典中安全获取
|
||||||
|
"category": best_meta['category'],
|
||||||
|
"gpt_description": item_description,
|
||||||
|
'description': best_meta['description'],
|
||||||
|
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
|
||||||
|
# 这里假设 item_id 就是文件名的一部分
|
||||||
|
"image_path": os.path.join(f"{item_id}.jpg")
|
||||||
|
}
|
||||||
|
|
||||||
def _build_user_input(self, recommend_acc=False) -> str:
|
def _build_user_input(self, recommend_acc=False) -> str:
|
||||||
"""构建发送给 Gemini 的用户输入,包含已选单品信息。"""
|
"""构建发送给 Gemini 的用户输入,包含已选单品信息。"""
|
||||||
if not self.outfit_items:
|
if not self.outfit_items:
|
||||||
@@ -353,7 +364,7 @@ class AsyncStylistAgent:
|
|||||||
response = post_request(url=callback_url, data=json.dumps(response_data), headers=self.headers)
|
response = post_request(url=callback_url, data=json.dumps(response_data), headers=self.headers)
|
||||||
logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
|
logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
|
||||||
|
|
||||||
async def run_styling_process(self, request_summary, occasions, stylist_name, batch_source="2025_q4", start_outfit=[], user_id="test", callback_url="", gender: str = "male"):
|
async def run_styling_process(self, request_summary, occasions, stylist_name, start_outfit=[], batch_sources=[], user_id="test", callback_url="", gender: str = "male"):
|
||||||
self.outfit_items = start_outfit
|
self.outfit_items = start_outfit
|
||||||
"""主流程控制循环。"""
|
"""主流程控制循环。"""
|
||||||
print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---")
|
print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---")
|
||||||
@@ -412,7 +423,7 @@ class AsyncStylistAgent:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# 4b. 在本地 DB 中查询单品
|
# 4b. 在本地 DB 中查询单品
|
||||||
new_item = self._get_next_item(description, category, occasions, batch_source, gender)
|
new_item = self._get_next_item(description, category, occasions, batch_sources, gender)
|
||||||
if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
|
if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
|
||||||
self.post_operation(
|
self.post_operation(
|
||||||
response_data,
|
response_data,
|
||||||
@@ -460,7 +471,7 @@ class AsyncStylistAgent:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# 4b. 在本地 DB 中查询单品
|
# 4b. 在本地 DB 中查询单品
|
||||||
new_item = self._get_next_item(description, category, occasions, batch_source, gender)
|
new_item = self._get_next_item(description, category, occasions, batch_sources, gender)
|
||||||
if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
|
if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -8,13 +8,12 @@ from PIL import Image
|
|||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
from transformers import CLIPProcessor, CLIPModel
|
from transformers import CLIPProcessor, CLIPModel
|
||||||
|
|
||||||
from app.taxonomy import CATEGORY, OCCASION
|
from app.taxonomy import OCCASION, ALL_CATEGORY
|
||||||
|
|
||||||
|
|
||||||
class VectorDatabase():
|
class VectorDatabase():
|
||||||
def __init__(self, vector_db_dir: str, collection_name: str, embedding_model_name: str):
|
def __init__(self, vector_db_dir: str, collection_name: str, embedding_model_name: str):
|
||||||
self.client = chromadb.PersistentClient(path=vector_db_dir)
|
self.client = chromadb.PersistentClient(path=vector_db_dir)
|
||||||
|
|
||||||
self.collection = self.client.get_or_create_collection(
|
self.collection = self.client.get_or_create_collection(
|
||||||
name=collection_name,
|
name=collection_name,
|
||||||
configuration={
|
configuration={
|
||||||
@@ -23,17 +22,9 @@ class VectorDatabase():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
|
self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
|
||||||
self.processor = CLIPProcessor.from_pretrained(embedding_model_name)
|
self.processor = CLIPProcessor.from_pretrained(embedding_model_name)
|
||||||
# self.cache_filtered_ids = self.load_filtered_ids([
|
|
||||||
# {"item_group_id": {"$ne": "Clothing"}},
|
|
||||||
# {"item_group_id": {"$ne": "Shoes"}},
|
|
||||||
# {"modality": "image"}
|
|
||||||
# ])
|
|
||||||
# self.total_count = len(self.cache_filtered_ids)
|
|
||||||
|
|
||||||
def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
|
def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
|
||||||
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
|
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
|
||||||
@@ -58,45 +49,31 @@ class VectorDatabase():
|
|||||||
|
|
||||||
return features.cpu().numpy().flatten().tolist()
|
return features.cpu().numpy().flatten().tolist()
|
||||||
|
|
||||||
def query_local_db(self, embedding: List[float], category: str, occasions: List[str] = [], n_results: int = 3) -> List[Dict[str, Any]]:
|
def get_matched_item(self, embedding: List[float], category: str, occasions: List[str] = [], batch_sources: List[str] = [], gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]:
|
||||||
"""
|
if category not in ALL_CATEGORY:
|
||||||
基于嵌入向量在本地数据库中查询相似单品。
|
raise ValueError(f"Recommended {category} is not valid.")
|
||||||
实际应执行 ChromaDB 查询,并根据 category 进行过滤(metadatas)。
|
|
||||||
"""
|
|
||||||
for occasion in occasions:
|
|
||||||
where_clauses = {
|
|
||||||
"$and": [
|
|
||||||
{"category": category},
|
|
||||||
{"modality": "image"},
|
|
||||||
{"batch_source": '2025_q4'}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
if occasion not in OCCASION:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
where_clauses['$and'].append({occasion: 1})
|
|
||||||
|
|
||||||
results = self.collection.query(
|
and_conditions = [
|
||||||
query_embeddings=[embedding],
|
{"category": category},
|
||||||
n_results=n_results,
|
{"modality": "image"},
|
||||||
where=where_clauses,
|
{"$or": [
|
||||||
include=['metadatas', 'distances']
|
{"gender": gender},
|
||||||
)
|
{"gender": "unisex"},
|
||||||
return results
|
]}
|
||||||
|
]
|
||||||
|
if batch_sources and len(batch_sources) > 0:
|
||||||
|
source_conditions = []
|
||||||
|
for source in batch_sources:
|
||||||
|
source_conditions.append({"batch_source": source})
|
||||||
|
|
||||||
|
# 将 Batch Source 的 OR 子句添加到主 AND 条件中
|
||||||
|
and_conditions.append({"$or": source_conditions})
|
||||||
|
|
||||||
def get_matched_item(self, embedding: List[float], category: str, occasions: List[str] = [], batch_source: str = "2025_q4", gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]:
|
|
||||||
results = self.collection.query(
|
results = self.collection.query(
|
||||||
query_embeddings=[embedding],
|
query_embeddings=[embedding],
|
||||||
n_results=500,
|
n_results=500,
|
||||||
where={
|
where={"$and": and_conditions},
|
||||||
"$and": [
|
include=['metadatas', 'distances'],
|
||||||
{"category": category},
|
|
||||||
{"modality": "image"},
|
|
||||||
{"gender": gender},
|
|
||||||
{"batch_source": batch_source}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
include=['metadatas', 'distances']
|
|
||||||
)
|
)
|
||||||
if not results['ids'][0]:
|
if not results['ids'][0]:
|
||||||
return []
|
return []
|
||||||
@@ -124,7 +101,7 @@ class VectorDatabase():
|
|||||||
|
|
||||||
score_occ = score_occ / count if count else 0.0
|
score_occ = score_occ / count if count else 0.0
|
||||||
|
|
||||||
final_score = 0.6 * score_vec + 0.3 * score_occ
|
final_score = 0.6 * score_vec + 0.4 * score_occ
|
||||||
final_scores.append(final_score)
|
final_scores.append(final_score)
|
||||||
|
|
||||||
scores_arr = np.array(final_scores)
|
scores_arr = np.array(final_scores)
|
||||||
@@ -139,79 +116,3 @@ class VectorDatabase():
|
|||||||
sampled_index = np.random.choice(a=len(results['ids'][0]), p=probabilities, size=n_results, replace=False) # 不重复采样
|
sampled_index = np.random.choice(a=len(results['ids'][0]), p=probabilities, size=n_results, replace=False) # 不重复采样
|
||||||
sampled_items = [metadatas[i] for i in sampled_index]
|
sampled_items = [metadatas[i] for i in sampled_index]
|
||||||
return sampled_items
|
return sampled_items
|
||||||
|
|
||||||
def load_filtered_ids(self, filter_item):
|
|
||||||
# print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
|
|
||||||
start_time = time.time()
|
|
||||||
FILTER_CRITERIA = {
|
|
||||||
"$and": filter_item
|
|
||||||
}
|
|
||||||
MAX_LIMIT = 100000
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 获取所有符合条件的 ID
|
|
||||||
all_ids_results = self.collection.get(
|
|
||||||
where=FILTER_CRITERIA,
|
|
||||||
limit=MAX_LIMIT,
|
|
||||||
include=[]
|
|
||||||
)
|
|
||||||
all_matched_ids = all_ids_results['ids']
|
|
||||||
# print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
|
|
||||||
print(time.time() - start_time)
|
|
||||||
return all_matched_ids
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def random_get_accessories(self, ids):
|
|
||||||
# 2. 调用 ChromaDB:只查询这一个 ID 的详细信息
|
|
||||||
try:
|
|
||||||
final_results = self.collection.get(
|
|
||||||
ids=ids,
|
|
||||||
include=["metadatas"] # 你只需要元数据
|
|
||||||
)
|
|
||||||
|
|
||||||
# 提取结果
|
|
||||||
if final_results['ids']:
|
|
||||||
return final_results
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 获取最终记录时发生错误: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
stylist = {
|
|
||||||
'text': "gold necklace",
|
|
||||||
'count': 2,
|
|
||||||
'category': "Jewelry"
|
|
||||||
}
|
|
||||||
|
|
||||||
max_len = 5
|
|
||||||
local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
|
|
||||||
A = local_db.load_filtered_ids([
|
|
||||||
{"item_group_id": {"$ne": "Clothing"}},
|
|
||||||
{"item_group_id": {"$ne": "Shoes"}},
|
|
||||||
{"modality": "image"}
|
|
||||||
])
|
|
||||||
# print(db.random_get_accessories())
|
|
||||||
start_time = time.time()
|
|
||||||
X = local_db.random_get_accessories(['ELI699_img'])
|
|
||||||
print(X)
|
|
||||||
print(time.time() - start_time)
|
|
||||||
# query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False)
|
|
||||||
#
|
|
||||||
# results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10)
|
|
||||||
# # 2. 从结果集中抽 stylist['count'] 个item
|
|
||||||
# stylist_item = random.choices(results['metadatas'][0], k=stylist['count'])
|
|
||||||
# stylist_item_ids = [item_id['item_id'] for item_id in stylist_item]
|
|
||||||
#
|
|
||||||
# # 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item
|
|
||||||
# accessories_count = 9 - max_len - stylist['count']
|
|
||||||
#
|
|
||||||
# random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
|
|
||||||
# random_items = local_db.random_get_accessories(random_single_ids)['metadatas']
|
|
||||||
# all_items = stylist_item + random_items
|
|
||||||
|
|||||||
@@ -1,15 +1,4 @@
|
|||||||
# 这个文件用来储存所有的category和occasion,这是标准文件。
|
# 这个文件用来储存所有的category和occasion,这是标准文件。
|
||||||
|
|
||||||
CATEGORY = [
|
|
||||||
'shoes', 'bags', 'dresses', 'tops', 'pants', 'skirts', 'outerwear', 'swimwear', 'suits',
|
|
||||||
'watches', 'sunglasses', 'belts', 'hats', 'jewelry', 'neckties', 'scarves & shawls'
|
|
||||||
]
|
|
||||||
CLOTHING_CATEGORY = [
|
|
||||||
'shoes', 'bags', 'dresses', 'tops', 'pants', 'skirts', 'outerwear', 'swimwear'
|
|
||||||
]
|
|
||||||
ACCESSORY_CATEGORY = [
|
|
||||||
'watches', 'sunglasses', 'belts', 'hats', 'jewelry', 'neckties', 'scarves & shawls'
|
|
||||||
]
|
|
||||||
OCCASION = [
|
OCCASION = [
|
||||||
"Casual", "Formal", "Activewear", "Resort", "Evening", "Outdoor",
|
"Casual", "Formal", "Activewear", "Resort", "Evening", "Outdoor",
|
||||||
"Business / workwear", "Cocktail / Semi-Formal", "Black Tie / White Tie",
|
"Business / workwear", "Cocktail / Semi-Formal", "Black Tie / White Tie",
|
||||||
@@ -17,3 +6,54 @@ OCCASION = [
|
|||||||
"Travel / Transit", "Athleisure", "Beach / Swim", "Ski / Snow / Mountain",
|
"Travel / Transit", "Athleisure", "Beach / Swim", "Ski / Snow / Mountain",
|
||||||
"Garden Party / Daytime Event"
|
"Garden Party / Daytime Event"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
CATEGORY = {
|
||||||
|
'clothing': [
|
||||||
|
'coats',
|
||||||
|
'jackets',
|
||||||
|
'blazers',
|
||||||
|
'puffer',
|
||||||
|
'cardigan',
|
||||||
|
'sweater',
|
||||||
|
'shirts',
|
||||||
|
't-shirts',
|
||||||
|
'pullover',
|
||||||
|
'polos',
|
||||||
|
'bodysuits',
|
||||||
|
'dresses',
|
||||||
|
'skirts',
|
||||||
|
'jeans',
|
||||||
|
'shorts',
|
||||||
|
'leggings',
|
||||||
|
'jumpsuits',
|
||||||
|
'swimwear',
|
||||||
|
],
|
||||||
|
'shoes': [
|
||||||
|
'sneakers',
|
||||||
|
'formal shoes',
|
||||||
|
'heels',
|
||||||
|
'flats',
|
||||||
|
'sandals',
|
||||||
|
'slides',
|
||||||
|
'boots',
|
||||||
|
],
|
||||||
|
'bags': [
|
||||||
|
'bags'
|
||||||
|
],
|
||||||
|
'accessories': [
|
||||||
|
'necklaces',
|
||||||
|
'bracelets',
|
||||||
|
'jewellery',
|
||||||
|
'eyewear',
|
||||||
|
'scarves',
|
||||||
|
'hats',
|
||||||
|
'gloves',
|
||||||
|
'belts',
|
||||||
|
'socks',
|
||||||
|
'watches'
|
||||||
|
'ties',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
ALL_CATEGORY = sum(CATEGORY.values(), [])
|
||||||
|
|
||||||
|
IGNORE_CATEGORY = ['socks']
|
||||||
92
data_ingestion/README.md
Normal file
92
data_ingestion/README.md
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
## Steps
|
||||||
|
1. Prepare products-all.json and image_data (folder) using javascript to download. These files should be saved in `./data/BATCH_SOURCE` which is a new folder. Give a new batch_source id to each new incoming data.
|
||||||
|
1. Run `process_item.py` to categorize category, gender and occasions for each data. Output to `./data/{BATCH_SOURCE}/metadata_extraction.json`. This should be running on H200 device.
|
||||||
|
3. Organize all data and then embed them into db locally using `run_ingestion.py`
|
||||||
|
|
||||||
|
## Raw Data Structure
|
||||||
|
```json
|
||||||
|
## products-all.json
|
||||||
|
{
|
||||||
|
"id": "BUL808",
|
||||||
|
"name": "SARAH ZHUANG - 'Click & Link' diamond 18k gold earrings",
|
||||||
|
"brand": "SARAH ZHUANG",
|
||||||
|
"category": "Fine Jewellery And Watches",
|
||||||
|
"subcategory": "General",
|
||||||
|
"price": 17500,
|
||||||
|
"currency": "HKD",
|
||||||
|
"description": "Sarah Zhuang's Click & Link earrings embrace the allure of geometry. Forged into elegant rectangles with one side encrusted with diamonds, this gold pair will certainly elevate your cocktail ensembles.",
|
||||||
|
"tags": [
|
||||||
|
"sarah zhuang",
|
||||||
|
"fine jewellery and watches",
|
||||||
|
"in-stock",
|
||||||
|
"new",
|
||||||
|
"sarah",
|
||||||
|
"zhuang",
|
||||||
|
"'click",
|
||||||
|
"link'",
|
||||||
|
"diamond"
|
||||||
|
],
|
||||||
|
"imageUrl": "https://media.lanecrawford.com/B/U/L/BUL808_in_xl.jpg",
|
||||||
|
"url": "https://www.lanecrawford.com.hk/product/sarah-zhuang/-click-link-diamond-18k-gold-earrings/_/BUL808/product.lc?utm_medium=embed&utm_source=ai-recommended&utm_campaign=2025-christmas_lc_ai-recommended",
|
||||||
|
"color": "YELLOW GOLD",
|
||||||
|
"groupName": "Fine Jewellery",
|
||||||
|
"deptName": "Women's Fine Jewellery",
|
||||||
|
"onlineBU": "Fine Jewellery",
|
||||||
|
"stockAvailability": true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Example in `metadata_extraction.json`
|
||||||
|
```json
|
||||||
|
"EOJ367": {
|
||||||
|
"category": "shoes",
|
||||||
|
"gender": "female",
|
||||||
|
"applicable_occasions": [
|
||||||
|
"Casual",
|
||||||
|
"Outdoor",
|
||||||
|
"Travel / Transit"
|
||||||
|
],
|
||||||
|
"inappropriate_occasions": [
|
||||||
|
"Formal",
|
||||||
|
"Black Tie / White Tie",
|
||||||
|
"Bridal / Wedding",
|
||||||
|
"Business / workwear",
|
||||||
|
"Cocktail / Semi-Formal"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Metadata in Vector Database
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
'item_id': 'EOJ128',
|
||||||
|
'category': 'sunglasses',
|
||||||
|
'gender': 'unisex',
|
||||||
|
'modality': 'image',
|
||||||
|
'brand': 'CELINE',
|
||||||
|
'color': 'BROWN',
|
||||||
|
'description': "Immerse yourself in the depth of classic style with CELINE\'s Tortoiseshell Logo Sunglasses. Featuring a rich, tortoiseshell acetate frame and adorned with the iconic CELINE logo in gold, these sunglasses are a testament to timeless elegance and luxury. Perfect for those who appreciate a sophisticated aesthetic, they offer optimal UV protection while ensuring you remain at the forefront of fashion.",
|
||||||
|
'tags': 'celine,accessories,in-stock,new,maxi,triomphe,acetate,round',
|
||||||
|
'price': 4500,
|
||||||
|
'url': 'https://www.lanecrawford.com.hk/product/celine/maxi-triomphe-acetate-round-sunglasses/_/EOJ128/product.lc?utm_medium=embed&utm_source=ai-recommended&utm_campaign=2025-christmas_lc_ai-recommended',
|
||||||
|
'batch_source': '2025_q4',
|
||||||
|
'Outdoor': 0,
|
||||||
|
'Ski / Snow / Mountain': 0,
|
||||||
|
'Festival / Concert': 0,
|
||||||
|
'Activewear': 0,
|
||||||
|
'Casual': 1,
|
||||||
|
'Cocktail / Semi-Formal': -1,
|
||||||
|
'Formal': -1,
|
||||||
|
'Party / Clubbing': 0,
|
||||||
|
'Evening': 0,
|
||||||
|
'Travel / Transit': 0,
|
||||||
|
'Beach / Swim': 0,
|
||||||
|
'Garden Party / Daytime Event': 1,
|
||||||
|
'Black Tie / White Tie': -1,
|
||||||
|
'Resort': 1,
|
||||||
|
'Athleisure': 0,
|
||||||
|
'Business / workwear': -1,
|
||||||
|
'Bridal / Wedding': -1,
|
||||||
|
}
|
||||||
|
```
|
||||||
280
data_ingestion/process_item.py
Normal file
280
data_ingestion/process_item.py
Normal file
@@ -0,0 +1,280 @@
|
|||||||
|
import torch
|
||||||
|
import os
|
||||||
|
from transformers import AutoProcessor, AutoModelForVision2Seq
|
||||||
|
from PIL import Image
|
||||||
|
import json
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from app.taxonomy import OCCASION, CATEGORY, ALL_CATEGORY
|
||||||
|
|
||||||
|
|
||||||
|
# data config
|
||||||
|
BATCH_SOURCE = '2025_q4'
|
||||||
|
RAW_DATA_PATH = f'./data/{BATCH_SOURCE}/products-all.json'
|
||||||
|
IMAGE_DIR = f'./data/{BATCH_SOURCE}/image_data'
|
||||||
|
|
||||||
|
# MLLM config
|
||||||
|
MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
||||||
|
DEVICE = "cuda:0" # 确保设备设置正确,与您的 Traceback 匹配
|
||||||
|
BATCH_SIZE = 50
|
||||||
|
OUTPUT_FILE = f'./data/{BATCH_SOURCE}/metadata_extraction.json'
|
||||||
|
|
||||||
|
|
||||||
|
# Load Model
|
||||||
|
processor = AutoProcessor.from_pretrained(MODEL_NAME)
|
||||||
|
if processor.tokenizer.padding_side != 'left':
|
||||||
|
processor.tokenizer.padding_side = 'left'
|
||||||
|
print(f"Set tokenizer padding_side to '{processor.tokenizer.padding_side}' for correct generation.")
|
||||||
|
model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16).to(DEVICE)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
# Load Data
|
||||||
|
with open(RAW_DATA_PATH, 'r', encoding='utf-8') as file:
|
||||||
|
data = json.load(file)
|
||||||
|
|
||||||
|
|
||||||
|
EXAMPLE_1_INFO = """
|
||||||
|
Product Name: ARMARIUM - Loren Wool Blend Tube Skirt
|
||||||
|
Category: Clothing / Bottoms
|
||||||
|
Color: RED
|
||||||
|
Description: Cut from cardinal-red virgin wool, Armarium's Loren skirt wields tailoring's exactitude in a column of colour. The low-slung waist and clean tube line are punctuated by a razor back slit—stride from boardroom to candlelit bar with modern hauteur.
|
||||||
|
Tags: armarium, clothing, in-stock, new, loren, wool, blend, tube
|
||||||
|
"""
|
||||||
|
EXAMPLE_1_JSON = json.dumps({
|
||||||
|
"category": "skirts",
|
||||||
|
"gender": "female",
|
||||||
|
"applicable_occasions": [
|
||||||
|
"Business/workwear", "Evening", "Cocktail / Semi-Formal", "Party / Clubbing", "Formal"
|
||||||
|
],
|
||||||
|
"inappropriate_occasions": [
|
||||||
|
"Activewear", "Beach / Swim", "Athleisure", "Ski / Snow / Mountain", "Casual"
|
||||||
|
]
|
||||||
|
}, indent=4)
|
||||||
|
|
||||||
|
# 示例 2:胸针 (Pin)
|
||||||
|
EXAMPLE_2_INFO = """
|
||||||
|
Product Name: TATEOSSIAN - Mayfair 18K Yellow Gold Rhodium Plated Sterling Silver Peg Pin
|
||||||
|
Category: Accessories / Accessories
|
||||||
|
Color: MULTI
|
||||||
|
Description: Crafted from 18k yellow gold and rhodium-plated sterling silver, this unique pins has been artfully finished with Tateossian's signature diamond engraving pattern.
|
||||||
|
Tags: tateossian, accessories, in-stock, new, mayfair, yellow, gold, rhodium
|
||||||
|
"""
|
||||||
|
EXAMPLE_2_JSON = json.dumps({
|
||||||
|
"category": "jewelry",
|
||||||
|
"gender": "female",
|
||||||
|
"applicable_occasions": [
|
||||||
|
"Formal", "Black Tie / White Tie", "Bridal / Wedding", "Business/workwear", "Cocktail / Semi-Formal"
|
||||||
|
],
|
||||||
|
"inappropriate_occasions": [
|
||||||
|
"Casual", "Activewear", "Beach / Swim", "Outdoor", "Athleisure", "Ski / Snow / Mountain"
|
||||||
|
]
|
||||||
|
}, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
# --- 2. 构造对话格式 Prompt ---
|
||||||
|
BOS_TOKEN = "<|begin_of_text|>"
|
||||||
|
EOS_TOKEN = "<|eot_id|>"
|
||||||
|
SYSTEM_HEADER = "<|start_header_id|>system<|end_header_id|>\n"
|
||||||
|
USER_HEADER = "<|start_header_id|>user<|end_header_id|>\n"
|
||||||
|
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n"
|
||||||
|
IMAGE_TOKEN = "<|image|>"
|
||||||
|
|
||||||
|
def format_product_info(product):
    """Render one raw product record as a multi-line text description.

    Args:
        product: dict with optional keys 'name', 'category', 'deptName',
            'color', 'description', 'tags', 'groupName', 'onlineBU'.

    Returns:
        A single newline-separated string; missing fields fall back to 'N/A'
        (description falls back to the empty string).
    """
    tags_str = ", ".join(product.get('tags', []))
    # BUG FIX: a stray comma after the Tags line made `info` a 2-tuple of
    # strings instead of one concatenated string (its tuple repr then leaked
    # into the prompt); the Tags line was also missing its trailing newline.
    info = (
        f"Product Name: {product.get('name', 'N/A')}\n"
        f"Category: {product.get('category', 'N/A')} / {product.get('deptName', 'N/A')}\n"
        f"Color: {product.get('color', 'N/A')}\n"
        f"Description: {product.get('description', '')}\n"
        f"Tags: {tags_str}\n"
        f"groupName: {product.get('groupName', 'N/A')}\n"
        f"onlineBU: {product.get('onlineBU', 'N/A')}\n"
    )
    return info
|
|
||||||
|
|
||||||
|
def generate_full_prompt(product_info, raw_category):
    """Build the full few-shot chat prompt for categorizing one product.

    Args:
        product_info: formatted product description text for the target item.
        raw_category: the catalogue's top-level category; selects which
            sub-category list the model may choose from.

    Returns:
        The complete prompt string, ending with an open assistant header so
        the model generates its JSON answer from that point.
    """
    # 'Fine Jewellery And Watches' shares the accessories sub-category list.
    if raw_category == 'Fine Jewellery And Watches':
        category = 'accessories'
    else:
        category = raw_category.lower()
    subcategory_list = CATEGORY.get(category)

    SYSTEM_PROMPT = f"""You are an expert fashion AI assistant. Your task is to analyze the provided product image and product details to:
1. determine the suitable occasions for wearing or using the item. You must choose occasions ONLY from the following strict list: {json.dumps(OCCASION, indent=4)}. Only relevant suitable or inappropriate occasions should be selected.
2. categorize it into suitable category in strict list: {json.dumps(subcategory_list)}.
3. categorize it into appropriate gender in ["female", "male", "unisex"]

Output Format:
Return ONLY a valid JSON object with four keys: "category", "gender", "applicable_occasions" and "inappropriate_occasions". Do not include any analysis or extra text outside of the final JSON object.
"""

    # Assemble the turns in order: system instruction, two few-shot rounds,
    # then the target query carrying the image token. The trailing assistant
    # header tells the model to start generating from there.
    turns = (
        f"{BOS_TOKEN}{SYSTEM_HEADER}{SYSTEM_PROMPT}{EOS_TOKEN}",
        f"{USER_HEADER}\n{EXAMPLE_1_INFO}{EOS_TOKEN}",
        f"{ASSISTANT_HEADER}{EXAMPLE_1_JSON}{EOS_TOKEN}",
        f"{USER_HEADER}\n{EXAMPLE_2_INFO}{EOS_TOKEN}",
        f"{ASSISTANT_HEADER}{EXAMPLE_2_JSON}{EOS_TOKEN}",
        f"{USER_HEADER}{IMAGE_TOKEN}\nInput Data:\n{product_info}{EOS_TOKEN}",
        ASSISTANT_HEADER,
    )
    return "".join(turns)
|
|
||||||
|
|
||||||
|
# 2. Load the catalogue and keep only supported categories that have an image
# file on disk (the downstream vision model needs the photo).
products = data['products']
product_list = [
    item for item in products
    if item.get('category') in {'Clothing', 'Accessories', 'Shoes', 'Bags', 'Fine Jewellery And Watches'}
    and os.path.exists(os.path.join(IMAGE_DIR, f"{item.get('id')}.jpg"))
]
|
|
||||||
|
|
||||||
|
def validate_results():
    """Load previously saved categorization output and find products still to do.

    A product is "unfinished" when it has no entry in OUTPUT_FILE, or its saved
    entry carries an invalid category or gender.

    Returns:
        (unfinished, final_results): the list of product dicts that still need
        processing, and the results dict loaded from OUTPUT_FILE (empty when
        the file does not exist yet).
    """
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, 'r') as f:
            final_results = json.load(f)
    else:
        final_results = {}

    unfinished_ids = []
    for product in product_list:
        item_id = product.get('id')
        if item_id not in final_results:
            unfinished_ids.append(product)
            continue
        processed_item = final_results[item_id]
        # BUG FIX: the old code appended the product TWICE when both the
        # category and the gender were invalid, so it got processed twice per
        # retry pass and skewed the completion ratio.
        if processed_item.get("category") not in ALL_CATEGORY:
            unfinished_ids.append(product)
        elif processed_item.get("gender") not in ('female', 'male', 'unisex'):
            unfinished_ids.append(product)
    return unfinished_ids, final_results
|
|
||||||
|
# Run up to 3 retry passes over the products that are still missing a valid
# categorization result, checkpointing OUTPUT_FILE after every batch.
attempts = 0
while attempts < 3:
    attempts += 1
    unfinished_products, final_results = validate_results()
    # Fraction of products that are still NOT categorized.
    unfinished_ratio = len(unfinished_products) / len(product_list)
    # BUG FIX: the old condition (`ratio > 0.95`) declared success when MORE
    # than 95% of products were still unfinished. Stop once fewer than 5%
    # remain, i.e. valid results surpass 95%.
    if unfinished_ratio < 0.05:
        print("valid results surpass 95%. Finish Now.")
        break
    else:
        print(f"Start {attempts} categorization process. Current ratio: {unfinished_ratio * 100}%")

    try:
        # Iterate over the remaining products in slices of BATCH_SIZE.
        for i in tqdm(range(0, len(unfinished_products), BATCH_SIZE)):
            batch_samples = unfinished_products[i:i + BATCH_SIZE]

            target_images = []
            target_prompts = []
            target_products_in_batch = []

            # Prepare the inputs for the current batch.
            for product in batch_samples:
                product_id = product['id']
                raw_category = product.get('category')
                image_path = os.path.join(IMAGE_DIR, f"{product_id}.jpg")

                try:
                    # Collect image, prompt and the product record itself.
                    image = Image.open(image_path).convert("RGB")
                    product_info = format_product_info(product)
                    full_prompt = generate_full_prompt(product_info, raw_category)

                    target_images.append(image)
                    target_prompts.append(full_prompt)
                    target_products_in_batch.append(product)
                except Exception as e:
                    # Skip any single sample that fails to load.
                    print(f"Skipping product {product_id} due to loading error: {e}")
                    continue

            if not target_images:
                continue  # The whole batch had no usable sample.

            # 4. Batched inference.
            print(f"\nProcessing batch {i//BATCH_SIZE + 1}/{int(len(unfinished_products)/BATCH_SIZE)+1} (Size: {len(target_images)})...")

            # Processor input: nested list [[img1], [img2], ...] — one image
            # per prompt.
            inputs = processor(
                images=[[img] for img in target_images],
                text=target_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=150,
                    do_sample=False
                )

            # 5. Decode and parse the batch results.
            # NOTE(review): slicing every row at input_ids.size(1) assumes the
            # tokenizer pads on the LEFT; confirm the processor is configured
            # with padding_side="left" for batched generation.
            input_lengths = inputs.input_ids.size(1)

            for j in range(len(target_products_in_batch)):
                product = target_products_in_batch[j]
                product_id = product['id']

                # outputs has shape [batch_size, sequence_length]; keep only
                # the newly generated continuation for sample j.
                newly_generated_tokens = outputs[j, input_lengths:]
                generated_text = processor.decode(newly_generated_tokens, skip_special_tokens=True)

                # Strip a trailing EOS token if the decoder left one in.
                if generated_text.endswith(processor.tokenizer.eos_token):
                    generated_text = generated_text[:-len(processor.tokenizer.eos_token)]

                try:
                    start_idx = generated_text.find('{')
                    end_idx = generated_text.rfind('}') + 1

                    # BUG FIX: rfind() returns -1 on failure, so after the +1
                    # the sentinel is 0, not -1 — the old `end_idx == -1`
                    # check could never fire.
                    if start_idx == -1 or end_idx == 0:
                        raise ValueError("JSON start or end delimiter not found.")

                    json_str = generated_text[start_idx:end_idx]
                    result_dict = json.loads(json_str)

                    final_results[product_id] = result_dict

                except Exception as e:
                    print(f"ID {product_id}: FAILED to parse JSON. Raw Output: {generated_text.strip()}")
                    final_results[product_id] = {"error": str(e), "raw_output": generated_text.strip()}

            # Free GPU memory between batches (optional, recommended in long runs).
            del inputs, outputs
            torch.cuda.empty_cache()

            # Checkpoint after every batch so a crash loses at most one batch.
            with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
                json.dump(final_results, f, indent=4, ensure_ascii=False)

        # 6. Save the final results.
        print("\n\n=== ALL BATCHES COMPLETE ===")

        with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, indent=4, ensure_ascii=False)

        print(f"Results saved to {OUTPUT_FILE}")

    except Exception as e:
        print(f"\n--- Execution Error ---")
        print(f"An unexpected error occurred: {e}")
178
data_ingestion/run_ingestion.py
Normal file
178
data_ingestion/run_ingestion.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import CLIPProcessor, CLIPModel
|
||||||
|
|
||||||
|
from app.taxonomy import CATEGORY, ALL_CATEGORY, OCCASION
|
||||||
|
|
||||||
|
|
||||||
|
# Ingestion batch identifier; stored in every DB record so a whole batch can
# later be filtered or deleted as a unit (see the commented delete below).
BATCH_SOURCE = '2025_q4'
DATA_DIR = f'./data/{BATCH_SOURCE}'
IMAGE_DIR = f'./data/{BATCH_SOURCE}/image_data'

# Raw catalogue export and the categorization output for the same batch.
RAW_DATA_PATH = f'{DATA_DIR}/products-all.json'
CATEGORIZED_METADATA_PATH = f'{DATA_DIR}/metadata_extraction.json'
|
|
||||||
|
## Load data: the raw catalogue and its per-item categorization metadata.
with open(RAW_DATA_PATH, 'r', encoding='utf-8') as raw_file:
    raw_data = json.load(raw_file)

with open(CATEGORIZED_METADATA_PATH, 'r', encoding='utf-8') as meta_file:
    categorized_data = json.load(meta_file)
|
|
||||||
|
|
||||||
|
# Create Collection: open (or create) the persistent ChromaDB collection that
# holds one image-embedding and one text-embedding record per product.
client = chromadb.PersistentClient(path='./data/db')
collection = client.get_or_create_collection(
    name="lc_clothing_embedding"
)

# if you wish to delete some item, uncomment following
# (removes every record previously ingested under this BATCH_SOURCE)
# results = collection.delete(
#     where={
#         "batch_source": BATCH_SOURCE
#     }
# )

# Load model: one CLIP checkpoint produces both image and text embeddings,
# so the two modalities live in the same vector space.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
|
|
||||||
|
def format_product_info(product):
    """Render one raw product record as a newline-separated text description.

    Used as the ChromaDB document text for both the image and text entries.

    Args:
        product: dict with optional keys 'name', 'brand', 'category',
            'deptName', 'color', 'description', 'tags', 'groupName', 'onlineBU'.

    Returns:
        A single multi-line string; missing fields fall back to 'N/A'
        (description falls back to the empty string).
    """
    tags_str = ", ".join(product.get('tags', []))
    # BUG FIX: the Tags line was missing its trailing '\n' (so it ran straight
    # into the GroupName line), and 'DeptName' was misspelled 'DetpName'.
    info = (
        f"Product Name: {product.get('name', 'N/A')}\n"
        f"Brand: {product.get('brand', 'N/A')}\n"
        f"Category: {product.get('category', 'N/A')} / {product.get('deptName', 'N/A')}\n"
        f"Color: {product.get('color', 'N/A')}\n"
        f"Description: {product.get('description', '')}\n"
        f"Tags: {tags_str}\n"
        f"GroupName: {product.get('groupName', 'N/A')}\n"
        f"DeptName: {product.get('deptName', 'N/A')}\n"
        f"OnlineBU: {product.get('onlineBU', 'N/A')}\n"
    )
    return info
|
|
||||||
|
|
||||||
|
# Combine all data together: join the raw catalogue with the categorization
# metadata, embed each valid product with CLIP, and insert it into ChromaDB.
new_category = {}   # unexpected category -> [item_ids]; reported at the end
valid_count = 0
all_count = 0
for raw_item in tqdm(raw_data['products']):
    item_id = raw_item.get('id')
    if not item_id:
        print(f"This item {raw_item} did not have a valid item_id")
        continue

    raw_category = raw_item.get("category")
    if raw_category not in ['Clothing', 'Accessories', 'Shoes', 'Bags', 'Fine Jewellery And Watches']:
        continue

    image_path = os.path.join(IMAGE_DIR, f"{item_id}.jpg")
    if not os.path.exists(image_path):
        print(f"Image not found: {image_path}")
        continue

    # All above is raw data error, it's not our business.
    all_count += 1

    processed_item = categorized_data.get(item_id, {})
    if not processed_item:
        print(f"{item_id} has not been categorized. It does not exist in {CATEGORIZED_METADATA_PATH}")
        continue

    category = processed_item.get("category")
    gender = processed_item.get("gender")
    applicable_occasions = processed_item.get("applicable_occasions", [])
    inappropriate_occasions = processed_item.get("inappropriate_occasions", [])

    if category not in ALL_CATEGORY:
        print(f"{item_id}'s category, {category}, does not valid.")
        # Remember unknown categories so they can be reviewed after the run.
        if category not in new_category:
            new_category[category] = [item_id]
        else:
            new_category[category].append(item_id)
        continue

    if gender not in ['female', 'male', 'unisex']:
        print(f"{item_id}'s gender is not valid in {['female', 'male', 'unisex']}")
        continue

    occasions = applicable_occasions + inappropriate_occasions
    if not set(occasions).issubset(set(OCCASION)):
        # Silently drop any occasion labels outside the canonical OCCASION list.
        applicable_occasions = [o for o in applicable_occasions if o in OCCASION]
        inappropriate_occasions = [o for o in inappropriate_occasions if o in OCCASION]

    description = raw_item.get('description', "")
    if not description:
        # BUG FIX: this message was a bare f-string expression (a no-op), so
        # the skip was never reported.
        print(f"{item_id}'s description is lost.")
        continue

    url = raw_item.get('url', '')
    if not url:
        # BUG FIX: same as above — the message was never printed.
        print(f"{item_id}'s url is lost.")
        continue

    valid_count += 1
    # Prepare metadata for db
    item_img_metadata = {
        "item_id": item_id,
        "category": category,
        "description": description,
        "gender": gender,
        'brand': raw_item.get('brand', ''),
        'color': raw_item.get('color', ''),
        'price': raw_item.get('price', ''),
        'tags': ",".join(raw_item.get('tags', [])),
        'url': url,
        "modality": "image",
        "batch_source": BATCH_SOURCE
    }
    # Encode occasion fit as a tri-state flag per occasion key:
    # 0 = unknown, 1 = applicable, -1 = inappropriate.
    for occasion in OCCASION:
        item_img_metadata[occasion] = 0
    for occasion in applicable_occasions:
        item_img_metadata[occasion] = 1
    for occasion in inappropriate_occasions:
        item_img_metadata[occasion] = -1

    # The text record carries identical metadata except for its modality tag.
    item_txt_metadata = deepcopy(item_img_metadata)
    item_txt_metadata["modality"] = "text"

    # Get image feature (L2-normalized CLIP embedding).
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        img_features = model.get_image_features(**inputs)
    img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
    img_embedding = img_features.cpu().numpy().flatten().tolist()

    # Get text feature (L2-normalized CLIP embedding of the description).
    inputs = processor(text=[description], return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        txt_features = model.get_text_features(**inputs)
    txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
    txt_embedding = txt_features.cpu().numpy().flatten().tolist()

    product_info = format_product_info(raw_item)
    # Insert both modalities for this item into ChromaDB.
    collection.add(
        ids=[f'{item_id}_img', f'{item_id}_txt'],
        documents=[product_info, product_info],
        embeddings=[img_embedding, txt_embedding],
        metadatas=[item_img_metadata, item_txt_metadata],
    )

print(f"Final valid ratio is {valid_count / all_count * 100}%. Total number is {all_count}, Valid number is {valid_count}")
print(f'Found new category for consideration: {new_category}')
|
||||||
BIN
docs/Edi.docx
Normal file
BIN
docs/Edi.docx
Normal file
Binary file not shown.
263
docs/LC Recommendation Workflow.drawio
Normal file
263
docs/LC Recommendation Workflow.drawio
Normal file
File diff suppressed because one or more lines are too long
BIN
docs/LC Recommendation Workflow.pdf
Normal file
BIN
docs/LC Recommendation Workflow.pdf
Normal file
Binary file not shown.
BIN
docs/LC Stylist Rules 总结.docx
Normal file
BIN
docs/LC Stylist Rules 总结.docx
Normal file
Binary file not shown.
BIN
docs/vera.docx
Normal file
BIN
docs/vera.docx
Normal file
Binary file not shown.
Reference in New Issue
Block a user