Merge remote-tracking branch 'origin/master'
This commit is contained in:
36
app/api/api_query_image.py
Normal file
36
app/api/api_query_image.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import json
|
||||
import logging
|
||||
from http.client import HTTPException
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.schemas.query_image import QueryImageModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.search_image_with_text.service import query
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/query_image")
def query_image(request_data: QueryImageModel):
    """
    Chatbot endpoint: look up images that match a text request.

    Send a request body with the following fields:

    - **gender**: gender used to filter the results
    - **content**: the text entered by the user

    Example body:

        {
            "gender": "male",
            "content": "give me a long sleeve blouse"
        }

    Returns a ``ResponseModel`` whose ``data`` is the value returned by
    ``query``. Any exception raised while handling the request is logged
    and reported to the client as an HTTP 404 carrying the error text.
    """
    # The module-level ``HTTPException`` is imported from ``http.client``;
    # that class accepts no keyword arguments, so raising it with
    # ``status_code``/``detail`` fails with a TypeError and the client sees a
    # 500. FastAPI only turns its own HTTPException into an HTTP response,
    # so bind the correct class locally.
    from fastapi import HTTPException

    try:
        logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict())}")
        data = query(request_data.gender, request_data.content)
        logger.info(f"query_image response @@@@@@:{json.dumps(data)}")
    except Exception as e:
        logger.warning(f"query_image Run Exception @@@@@@:{e}")
        # Chain the original error so the traceback keeps the root cause.
        raise HTTPException(status_code=404, detail=str(e)) from e
    return ResponseModel(data=data)
|
||||
6
app/schemas/query_image.py
Normal file
6
app/schemas/query_image.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class QueryImageModel(BaseModel):
    # Request body schema for the /query_image endpoint.

    # Gender value the search results are filtered by.
    gender: str

    # Free-text description entered by the user.
    content: str
|
||||
89
app/service/search_image_with_text/service.py
Normal file
89
app/service/search_image_with_text/service.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import chromadb
|
||||
import hashlib
|
||||
|
||||
import pandas as pd
|
||||
from chromadb.config import Settings
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
|
||||
from tqdm import tqdm
|
||||
|
||||
# Read the CSV file
|
||||
# csv_file_path = r'D:/Files/csv/output/output.csv'
|
||||
# image_path = r'D:/images-clean'
|
||||
|
||||
# df = pd.read_csv(csv_file_path, encoding='Windows-1252')
|
||||
|
||||
# Create the persistent Chroma client; the vector store lives in /vector_db.
# Alternative persist directories that were used previously:
#   ./service/search_image_with_text/vector_db
#   D:/workspace/AiDLab/vector_db
client = chromadb.Client(
    Settings(
        is_persistent=True,
        persist_directory="/vector_db",
    )
)

# Embedding function handed to the collection; it is configured with the
# Ollama embeddings URL and the mxbai-embed-large model.
embedding_fn = OllamaEmbeddingFunction(
    url="http://10.1.1.240:11434/api/embeddings",
    model_name="mxbai-embed-large",
)
|
||||
|
||||
|
||||
# def create_collection():
|
||||
# collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn)
|
||||
#
|
||||
# # Store the data, including custom attributes
|
||||
# images_description = []
|
||||
# images_metadata = []
|
||||
# ids = []
|
||||
# batch_size = 41666 # maximum batch size
|
||||
# for index, row in tqdm(df.iterrows()):
|
||||
# # Use the MD5 of the image as its id
|
||||
# with open(image_path + row['path'], 'rb') as f:
|
||||
# image_data = f.read()
|
||||
# md5_value = hashlib.md5(image_data).hexdigest()
|
||||
# ids.append(md5_value)
|
||||
# images_description.append(row['description'])
|
||||
# images_metadata.append({
|
||||
# "gender": row['gender'],
|
||||
# "path": row['path']
|
||||
# })
|
||||
#
|
||||
# # Add the data to the collection
|
||||
# # Run an upsert every time batch_size is reached
|
||||
# if len(ids) >= batch_size:
|
||||
# collection.upsert(
|
||||
# ids=list(ids),
|
||||
# documents=images_description,
|
||||
# metadatas=images_metadata # add the custom attributes
|
||||
# )
|
||||
# # Clear the lists to get ready for the next batch
|
||||
# ids.clear()
|
||||
# images_description.clear()
|
||||
# images_metadata.clear()
|
||||
#
|
||||
# if ids:
|
||||
# collection.upsert(
|
||||
# ids=list(ids),
|
||||
# documents=images_description,
|
||||
# metadatas=images_metadata # add the custom attributes
|
||||
# )
|
||||
#
|
||||
# print("Data successfully stored in the vector database.")
|
||||
|
||||
|
||||
def query(gender: str, content: str, n_results: int = 5) -> list:
    """Return the image paths whose stored descriptions best match ``content``.

    The ``sub_sketches_description`` collection is searched with ``content``
    as the query text, restricted to records whose ``gender`` metadata equals
    ``gender``.

    Args:
        gender: Gender value the record metadata must match.
        content: Free-text description entered by the user.
        n_results: Maximum number of matches to return. Defaults to 5, the
            value that used to be hard-coded.

    Returns:
        The ``path`` metadata value of each match, in the order the
        collection returned them.
    """
    collection = client.get_collection("sub_sketches_description", embedding_function=embedding_fn)

    # Find similar content, filtered by gender.
    results = collection.query(
        query_texts=content,
        n_results=n_results,
        where={"gender": gender},
    )

    # A single query text was sent, so only the first result list is relevant.
    return [metadata["path"] for metadata in results["metadatas"][0]]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: run one sample query when executed as a script.
    # The collection build step stays disabled:
    # create_collection()
    query(gender="female", content="I need a long sleeve dress")
|
||||
Reference in New Issue
Block a user