新增 取消agent配饰(保留鞋子)推荐,改为默认随机配饰搭配 使用json文件补充stylist删除掉的必要配饰
This commit is contained in:
@@ -40,8 +40,7 @@ COPY . /app
|
||||
# Install litserve and requirements
RUN pip install --upgrade pip setuptools wheel
RUN pip install --no-cache-dir litserve==0.2.16 -r requirements.txt

# Install the CUDA 12.1 builds of torch/torchvision/torchaudio from the
# official PyTorch index. NOTE(review): the earlier plain
# `RUN pip install torch torchvision` was removed — it downloaded the default
# wheels only for them to be replaced by this pinned-index install, adding a
# useless (and large) image layer.
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

EXPOSE 8000

CMD ["python", "-m","app.main"]
# Debug alternative: keep the container alive without starting the app.
#CMD ["tail", "-f","/dev/null"]
||||
@@ -1,6 +1,7 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from io import BytesIO
|
||||
|
||||
import cv2
|
||||
@@ -115,10 +116,10 @@ if __name__ == '__main__':
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/5de155d0-56a6-43e8-a2f1-7538fce86220.jpg"
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/1cd1803c-5f51-4961-a4f2-2acd3e0d8294.jpg"
|
||||
# OSS object keys of the outfit-item test images to run through the pipeline.
url = [
    'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/d9df7c48-c7e5-47f9-be67-07f0d175d202.jpg',
    'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/ddf39b9c-69f0-4b28-95ed-9d823fa82e35.jpg',
    'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/112194a0-dc1d-4151-8c58-82642142a553.jpg',
    # BUG FIX: this entry was missing its trailing comma, so Python's implicit
    # string concatenation silently fused it with the next entry into one
    # bogus path, dropping a test image from the run.
    'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251120/788007f1-e44b-4390-ad9e-a2d4ba406379.jpg',
    'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251121/4b595d3b-5d3d-4617-ae09-5fca92d935f7.jpg',
    'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251121/6d0d7540-5b61-45f2-a1fa-5cb1c7a3d0fa.jpg',
    'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251121/a4e51ccb-9b95-4718-8153-92ee0a39d0c8.jpg',
    'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251121/cbebbcf6-cca2-4460-9f9f-d0b1000dc2cd.jpg',
]
read_type = "1"
||||
for id, i in enumerate(url):
|
||||
|
||||
90
app/test/chromadb/embedding_query.py
Normal file
90
app/test/chromadb/embedding_query.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import random
|
||||
|
||||
import chromadb
|
||||
from typing import Set, List, Dict, Union, Any
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import no_grad
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from app.server.utils.minio_client import oss_get_image, minio_client
|
||||
from app.server.utils.minio_config import MINIO_LC_DATA_PATH
|
||||
|
||||
# --- Local ChromaDB configuration ---
DB_PATH = "/workspace/lc_stylist_agent/db"
COLLECTION_NAME = 'lc_clothing_embedding'
# Limit large enough to fetch every record in one call; switch to pagination
# if the collection ever exceeds this many records.
MAX_LIMIT = 1000000

client = chromadb.PersistentClient(path=DB_PATH)
try:
    collection = client.get_collection(name=COLLECTION_NAME)
    print(f"✅ 连接到 Collection: {COLLECTION_NAME}")
except ValueError:
    # NOTE(review): older chromadb releases raise ValueError for a missing
    # collection; newer ones raise chromadb.errors.* — confirm against the
    # pinned chromadb version.
    print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。")
    # Leave `collection` as None so downstream code can detect the failure.
    collection = None

# NOTE(review): removed an exact duplicate of
# `from transformers import CLIPModel, CLIPProcessor` that appeared here —
# the same import already exists at the top of the file.
||||
# Cache of loaded CLIP model/processor pairs keyed by (model_name, device),
# so repeated calls do not re-instantiate the heavyweight model each time.
_CLIP_CACHE: dict = {}


def _load_clip(model_name: str, device: str):
    """Load (once) and return the cached (CLIPModel, CLIPProcessor) pair."""
    key = (model_name, device)
    if key not in _CLIP_CACHE:
        model = CLIPModel.from_pretrained(model_name).to(device)
        processor = CLIPProcessor.from_pretrained(model_name)
        _CLIP_CACHE[key] = (model, processor)
    return _CLIP_CACHE[key]


def get_clip_embedding(data: str | Image.Image) -> List[float]:
    """Generate an L2-normalized CLIP embedding for text or a PIL image.

    BUG FIX: the signature always accepted `str | Image.Image`, but only the
    text branch was implemented — passing an image crashed the processor.
    Both modalities are now handled. The model is also cached instead of being
    re-loaded on every call.

    Args:
        data: A text query or a PIL image to embed.

    Returns:
        The embedding as a flat list of floats, L2-normalized so cosine
        similarity reduces to a dot product.
    """
    embedding_model_name = "openai/clip-vit-base-patch32"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = _load_clip(embedding_model_name, device)

    if isinstance(data, Image.Image):
        inputs = processor(images=data, return_tensors="pt").to(device)
        with no_grad():
            features = model.get_image_features(**inputs)
    else:
        # Force truncation to avoid sequence-length errors on long text.
        inputs = processor(
            text=[data],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        with no_grad():
            features = model.get_text_features(**inputs)

    # L2-normalize the feature vector.
    features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features.cpu().numpy().flatten().tolist()
||||
def query_local_db(embedding: List[float], category: str, n_results: int = 3) -> Dict[str, Any]:
    """Query the local ChromaDB collection for items similar to `embedding`.

    Results are restricted via metadata filters to image-modality records of
    the given `category`.

    Args:
        embedding: The query embedding vector.
        category: Metadata `category` value to filter on (e.g. "Watches").
        n_results: Maximum number of matches to return.

    Returns:
        The raw ChromaDB query result dict (keys such as 'ids', 'documents',
        'metadatas', 'distances', each holding one list per query embedding).
        NOTE(review): the original annotation said List[Dict[...]], but
        callers index the result as a dict (`result['ids'][0]`), so the
        annotation was corrected.

    Raises:
        RuntimeError: if the collection failed to load at import time.
    """
    if collection is None:
        # Fail fast with a clear message instead of an AttributeError on None.
        raise RuntimeError(
            f"Collection '{COLLECTION_NAME}' is not available; cannot query."
        )
    return collection.query(
        query_embeddings=[embedding],
        n_results=n_results,
        where={
            "$and": [
                {"category": category},
                {"modality": "image"},
            ]
        },
        include=['documents', 'metadatas', 'distances'],
    )
|
||||
if __name__ == '__main__':
    # Smoke test: embed the text "watch", fetch similar "Watches" items from
    # the local vector store, then pick two result ids at random.
    query_vec = get_clip_embedding("watch")
    print(query_vec)

    matches = query_local_db(query_vec, "Watches", 20)
    print(matches)

    candidate_ids = matches['ids'][0]

    # random.choices samples WITH replacement, so the two picks may repeat.
    picked_ids = random.choices(candidate_ids, k=2)
    print(picked_ids)

    # Optional: download and save each matched image locally for inspection.
    # for id in ids:
    #     path = id.replace("_img", ".jpg")
    #     img = oss_get_image(oss_client=minio_client, path=f"{MINIO_LC_DATA_PATH}/{path}", data_type="PIL").convert('RGB')
    #     img.save(path)
|
||||
@@ -13,4 +13,12 @@ services:
|
||||
- ./db:/db
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
ports:
  # Host 10070 -> container 8000. This mapping was listed twice; a single
  # entry is sufficient.
  - "10070:8000"
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
# 告诉 Docker 使用所有可用的 NVIDIA GPU
|
||||
- driver: nvidia
|
||||
device_ids: ['0']
|
||||
capabilities: [ gpu ]
|
||||
Reference in New Issue
Block a user