# CLIP embedding + ChromaDB similarity-search utilities for the LC stylist agent.
import random
from typing import Any, Dict, List, Set, Union

import chromadb
import torch
from PIL import Image
from torch import no_grad
from transformers import CLIPModel, CLIPProcessor

from app.server.utils.minio_client import minio_client, oss_get_image
from app.server.utils.minio_config import MINIO_LC_DATA_PATH

# --- Configuration ---
DB_PATH = "/workspace/lc_stylist_agent/db"
COLLECTION_NAME = 'lc_clothing_embedding'
# A limit large enough to fetch every record in one call; switch to
# pagination if the collection ever grows beyond this.
MAX_LIMIT = 1000000

client = chromadb.PersistentClient(path=DB_PATH)

try:
    collection = client.get_collection(name=COLLECTION_NAME)
    print(f"✅ 连接到 Collection: {COLLECTION_NAME}")
except ValueError:
    # NOTE(review): newer chromadb releases raise chromadb.errors.NotFoundError
    # from get_collection, which is not guaranteed to be a ValueError subclass —
    # confirm against the pinned chromadb version.
    print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。")
    # Downstream code must check for None and skip queries when absent.
    collection = None
# Cache of loaded (model, processor) pairs keyed by (model_name, device), so
# repeated calls do not re-load the weights from disk every time.
_CLIP_RESOURCES: Dict[Any, Any] = {}


def _load_clip(model_name: str, device: str):
    """Return a cached ``(CLIPModel, CLIPProcessor)`` pair for *model_name* on *device*."""
    key = (model_name, device)
    if key not in _CLIP_RESOURCES:
        model = CLIPModel.from_pretrained(model_name).to(device)
        processor = CLIPProcessor.from_pretrained(model_name)
        _CLIP_RESOURCES[key] = (model, processor)
    return _CLIP_RESOURCES[key]


def get_clip_embedding(data: str | Image.Image) -> List[float]:
    """Generate an L2-normalised CLIP embedding for a text string or a PIL image.

    Args:
        data: either a text query (``str``) or a ``PIL.Image.Image``.

    Returns:
        The embedding as a flat list of floats (unit L2 norm).
    """
    embedding_model_name = "openai/clip-vit-base-patch32"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = _load_clip(embedding_model_name, device)

    if isinstance(data, Image.Image):
        # Image branch: the original implementation only handled text even
        # though the signature advertises Image support.
        inputs = processor(images=data, return_tensors="pt").to(device)
        with no_grad():
            features = model.get_image_features(**inputs)
    else:
        # Force truncation to avoid exceeding CLIP's max text sequence length.
        inputs = processor(
            text=[data],
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(device)
        with no_grad():
            features = model.get_text_features(**inputs)

    # L2-normalise so cosine similarity reduces to a plain dot product.
    features = features / features.norm(p=2, dim=-1, keepdim=True)

    return features.cpu().numpy().flatten().tolist()
def query_local_db(embedding: List[float], category: str, n_results: int = 3) -> Dict[str, Any]:
    """Query the local ChromaDB collection for items similar to *embedding*.

    Restricts matches to records whose metadata has the given ``category``
    and ``modality == "image"``.

    Args:
        embedding: query vector (a CLIP embedding).
        category: metadata ``category`` value to filter on.
        n_results: maximum number of matches to return.

    Returns:
        The raw ChromaDB query result dict (``ids``/``documents``/
        ``metadatas``/``distances``, each a list-of-lists per query).

    Raises:
        RuntimeError: if the collection was not found at import time
            (module-level ``collection`` is ``None``).
    """
    # Guard: the import-time setup sets `collection` to None when the
    # collection does not exist; fail with a clear message instead of an
    # AttributeError on None.
    if collection is None:
        raise RuntimeError(
            f"Collection '{COLLECTION_NAME}' is unavailable; cannot query."
        )

    return collection.query(
        query_embeddings=[embedding],
        n_results=n_results,
        where={
            "$and": [
                {"category": category},
                {"modality": "image"},
            ]
        },
        include=['documents', 'metadatas', 'distances'],
    )
if __name__ == '__main__':
    # Smoke test: embed a text query and retrieve similar watch items.
    query_vec = get_clip_embedding("watch")
    print(query_vec)

    matches = query_local_db(query_vec, "Watches", 20)
    print(matches)

    # ChromaDB returns one id-list per query vector; we sent a single query.
    matched_ids = matches['ids'][0]

    # NOTE(review): random.choices samples WITH replacement, so the two picks
    # may repeat — confirm whether random.sample was intended here.
    picked_ids = random.choices(matched_ids, k=2)
    print(picked_ids)

    # for id in ids:
    #     path = id.replace("_img", ".jpg")
    #     img = oss_get_image(oss_client=minio_client, path=f"{MINIO_LC_DATA_PATH}/{path}", data_type="PIL").convert('RGB')
    #     img.save(path)