# File: lc_stylist_agent/app/test/chromadb/embedding_query.py
# (header reconstructed — the original paste carried file-viewer UI chrome:
#  "Files", line/size counts, "Raw Permalink Normal View History")
import random
from functools import lru_cache
from typing import Any, Dict, List, Set, Union

import chromadb
import torch
from PIL import Image
from torch import no_grad
from transformers import CLIPModel, CLIPProcessor

from app.server.utils.minio_client import minio_client, oss_get_image
from app.server.utils.minio_config import MINIO_LC_DATA_PATH
# --- 你的配置 ---
DB_PATH = "/workspace/lc_stylist_agent/db"
COLLECTION_NAME = 'lc_clothing_embedding'
# 设置一个足够大的限制来获取所有记录,或者使用分页(如果记录数非常庞大)
MAX_LIMIT = 1000000
client = chromadb.PersistentClient(path=DB_PATH)
try:
collection = client.get_collection(name=COLLECTION_NAME)
print(f"✅ 连接到 Collection: {COLLECTION_NAME}")
except ValueError:
print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。")
# 如果 collection 不存在,我们将跳过后续操作
collection = None
from transformers import CLIPModel, CLIPProcessor
def get_clip_embedding(data: str | Image.Image) -> List[float]:
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
embedding_model_name = "openai/clip-vit-base-patch32"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained(embedding_model_name).to(device)
processor = CLIPProcessor.from_pretrained(embedding_model_name)
# 强制截断,解决序列长度问题
inputs = processor(
text=[data],
return_tensors="pt",
padding=True,
truncation=True
).to(device)
with no_grad():
features = model.get_text_features(**inputs)
# L2 归一化
features = features / features.norm(p=2, dim=-1, keepdim=True)
return features.cpu().numpy().flatten().tolist()
def query_local_db(embedding: List[float], category: str, n_results: int = 3) -> List[Dict[str, Any]]:
"""
基于嵌入向量在本地数据库中查询相似单品
实际应执行 ChromaDB 查询,并根据 category 进行过滤(metadatas)
"""
# 实际应执行向量查询
# 为了演示流程,返回一个模拟结果
results = collection.query(
query_embeddings=[embedding],
n_results=n_results,
where={
"$and": [
{"category": category},
{"modality": "image"},
]
},
include=['documents', 'metadatas', 'distances']
)
return results
if __name__ == '__main__':
embedding = get_clip_embedding("watch")
print(embedding)
result = query_local_db(embedding, "Watches", 20)
print(result)
ids = result['ids'][0]
random_single_id = random.choices(ids, k=2)
print(random_single_id)
# for id in ids:
# path = id.replace("_img", ".jpg")
# img = oss_get_image(oss_client=minio_client, path=f"{MINIO_LC_DATA_PATH}/{path}", data_type="PIL").convert('RGB')
# img.save(path)