"""CLIP-embedding lookup against a local ChromaDB clothing collection.

Importing this module opens the persistent ChromaDB client at ``DB_PATH`` and
tries to bind ``collection`` to ``COLLECTION_NAME``; ``collection`` is left as
``None`` when that lookup raises ``ValueError``.
"""
import functools
import random
from typing import Set, List, Dict, Union, Any

import chromadb
import torch
from PIL import Image
from torch import no_grad
from transformers import CLIPModel, CLIPProcessor

from app.server.utils.minio_client import oss_get_image, minio_client
from app.server.utils.minio_config import MINIO_LC_DATA_PATH

# --- Configuration ---
DB_PATH = "/workspace/lc_stylist_agent/db"
COLLECTION_NAME = 'lc_clothing_embedding'
# A limit large enough to fetch every record; switch to pagination if the
# number of records becomes very large.
MAX_LIMIT = 1000000

# CLIP checkpoint used for both text and image embeddings.
_EMBEDDING_MODEL_NAME = "openai/clip-vit-base-patch32"

client = chromadb.PersistentClient(path=DB_PATH)

# NOTE(review): which exception a missing collection raises depends on the
# installed chromadb version -- confirm ValueError is still the right one.
try:
    collection = client.get_collection(name=COLLECTION_NAME)
    print(f"✅ 连接到 Collection: {COLLECTION_NAME}")
except ValueError:
    print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。")
    # Without a collection every later query is impossible; mark it as absent.
    collection = None


@functools.lru_cache(maxsize=1)
def _load_clip(model_name: str, device: str):
    """Load the CLIP model and processor once and reuse them across calls."""
    model = CLIPModel.from_pretrained(model_name).to(device)
    processor = CLIPProcessor.from_pretrained(model_name)
    return model, processor


def get_clip_embedding(data: str | Image.Image) -> List[float]:
    """Return the L2-normalised CLIP embedding of a text string or a PIL image.

    Args:
        data: A text string (embedded with the text tower) or a
            ``PIL.Image.Image`` (embedded with the image tower).

    Returns:
        The embedding flattened into a plain list of floats.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = _load_clip(_EMBEDDING_MODEL_NAME, device)

    if isinstance(data, Image.Image):
        # Images must go through the image tower; feeding them to the
        # tokenizer as text cannot work.
        inputs = processor(images=data, return_tensors="pt").to(device)
        with no_grad():
            features = model.get_image_features(**inputs)
    else:
        # Truncation is forced so over-long text does not exceed the model's
        # maximum sequence length.
        inputs = processor(
            text=[data],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        with no_grad():
            features = model.get_text_features(**inputs)

    # L2 normalisation.
    features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features.cpu().numpy().flatten().tolist()


def query_local_db(embedding: List[float], category: str, n_results: int = 3) -> Dict[str, Any]:
    """Query the local ChromaDB collection for items similar to ``embedding``.

    Only records whose metadata has ``category`` equal to the given value and
    ``modality`` equal to ``"image"`` are considered.

    Args:
        embedding: Query vector, e.g. the output of ``get_clip_embedding``.
        category: Value the ``category`` metadata field must match.
        n_results: Number of results requested from the collection.

    Returns:
        The result of ``collection.query`` with documents, metadatas and
        distances included.

    Raises:
        RuntimeError: If the collection was not found when the module loaded.
    """
    if collection is None:
        raise RuntimeError(
            f"Collection '{COLLECTION_NAME}' is not available; cannot run query."
        )

    results = collection.query(
        query_embeddings=[embedding],
        n_results=n_results,
        where={
            "$and": [
                {"category": category},
                {"modality": "image"},
            ]
        },
        include=['documents', 'metadatas', 'distances'],
    )
    return results


if __name__ == '__main__':
    embedding = get_clip_embedding("watch")
    print(embedding)
    result = query_local_db(embedding, "Watches", 20)
    print(result)
    ids = result['ids'][0]
    # Sample without replacement so the same id is never picked twice, and
    # bound k so a short or empty result list does not raise.
    random_single_id = random.sample(ids, k=min(2, len(ids)))
    print(random_single_id)
    # Example: download the matched images from object storage.
    # for item_id in ids:
    #     path = item_id.replace("_img", ".jpg")
    #     img = oss_get_image(oss_client=minio_client, path=f"{MINIO_LC_DATA_PATH}/{path}", data_type="PIL").convert('RGB')
    #     img.save(path)