Compare commits

...

10 Commits

Author SHA1 Message Date
zhouchengrong
5b227e3008 feat 新增 process lookbooks 接口
fix
2024-10-24 17:35:46 +08:00
shahaibo
07e72c1ee1 TASK:lookbook上传,查询 2024-10-24 15:59:36 +08:00
zhouchengrong
93c37e268a feat 新增 process lookbooks 接口
fix
2024-10-23 10:07:42 +08:00
zhouchengrong
dd0bdb16b5 feat 新增 process lookbooks 接口
fix
2024-10-22 15:43:32 +08:00
zhouchengrong
d9f35f9faa feat 新增 process lookbooks 接口
fix
2024-10-22 15:34:33 +08:00
zhouchengrong
7da300b2c9 feat 新增 process lookbooks 接口
fix
2024-10-22 15:34:10 +08:00
zhouchengrong
c8483fdc0c feat 新增 process lookbooks 接口
fix
2024-10-22 15:20:03 +08:00
zhouchengrong
6e0942ca3c feat 新增 process lookbooks 接口
fix
2024-10-22 15:11:49 +08:00
zhouchengrong
ed017fdf9d feat 新增 process lookbooks 接口
fix
2024-10-22 15:11:08 +08:00
shahaibo
61af80541b TASK:lookbook上传 2024-10-22 12:11:04 +08:00
14 changed files with 148 additions and 40 deletions

View File

@@ -1,4 +1,4 @@
FROM python:3.9
FROM python:3.12
ENV TZ=Asia/Shanghai
RUN apt update
RUN apt install -y vim
@@ -9,6 +9,7 @@ RUN apt install -y tesseract-ocr
COPY ./requirements.txt /requirements.txt
RUN pip install --upgrade pip
RUN pip install -r requirements.txt
RUN pip install --upgrade grpcio
RUN pip install gunicorn
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
RUN pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html

View File

@@ -16,11 +16,10 @@
1. 安装依赖
$ conda create -n trinity_client_mixi python=3.9 -y
$ conda create -n trinity_client_mixi python=3.12 -y
$ conda activate trinity_client_mixi
$ pip install -r requirements.txt
$ conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
$ pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
2. 启动服务器

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger()
router = APIRouter()
@router.post("")
@router.post("/attribute")
def attribute(request_data: AttributeModel):
logger.info(f"attribute requests is @@@@@@@@@@@:{request_data}")
service = AttributeRecognition()

View File

@@ -0,0 +1,26 @@
from fastapi import APIRouter, Query
from typing import Optional
from app.service.lookbooks.query_service import query_lookbooks_service # 引入业务逻辑
router = APIRouter()
@router.get("/query-lookbooks")
async def query_lookbooks(
tag: Optional[str] = Query(None, description="Tag to filter lookbooks"),
year: Optional[str] = Query(None, description="Year to filter lookbooks"),
n_results: int = Query(10, description="Number of results to return")
):
"""
查询向量数据库,支持按 tag 和 year 查询
:param tag: 查询过滤的标签
:param year: 查询过滤的年份
:param n_results: 返回的结果数量
:return: 查询结果
"""
try:
# 调用业务逻辑层的查询服务
result_list = await query_lookbooks_service(tag, year, n_results)
return {"status": "success", "data": result_list}
except Exception as e:
return {"status": "error", "message": str(e)}

View File

@@ -16,7 +16,7 @@ logger = logging.getLogger()
router = APIRouter()
@router.post("outfit_matcher")
@router.post("/outfit_matcher")
def outfit_matcher(request_item: OutfitMatcher):
start_time = time.time()
request_item = dict(request_item)

View File

@@ -1,18 +1,17 @@
import logging
import os
import shutil
from typing import List
import aiofiles
import tqdm
from fastapi import UploadFile, File, APIRouter, BackgroundTasks, Form
from app.service.lookbooks.service import create_image_batch_requests, process_lookbook_task # 引入服务逻辑
from app.service.lookbooks.service import process_lookbook_task # 引入服务逻辑
logger = logging.getLogger()
router = APIRouter()
@router.post("/process-lookbooks/")
@router.post("/process-lookbooks")
async def process_lookbooks(
background_tasks: BackgroundTasks,
files: List[UploadFile] = File(...),

View File

@@ -1,11 +1,13 @@
from fastapi import APIRouter
from app.api import api_test, api_outfit_matcher, api_attribute, api_similar_match, api_process_lookbooks
from app.api import api_test, api_outfit_matcher, api_attribute, api_similar_match, api_process_lookbooks, \
api_lookbooks_query
router = APIRouter()
router.include_router(api_test.router, tags=["test"], prefix="/test")
router.include_router(api_outfit_matcher.router, tags=["outfit_matcher"], prefix="/api/outfit_matcher")
router.include_router(api_attribute.router, tags=["attribute"], prefix="/api/attribute")
router.include_router(api_similar_match.router, tags=["similar_match"], prefix="/api/similar_match")
router.include_router(api_process_lookbooks.router, tags=["process_lookbooks"], prefix="/api/process_lookbooks")
router.include_router(api_outfit_matcher.router, tags=["outfit_matcher"], prefix="/api")
router.include_router(api_attribute.router, tags=["attribute"], prefix="/api")
router.include_router(api_similar_match.router, tags=["similar_match"], prefix="/api")
router.include_router(api_process_lookbooks.router, tags=["process-lookbooks"], prefix="/api")
router.include_router(api_lookbooks_query.router, tags=["query-lookbooks"], prefix="/api")

View File

@@ -16,11 +16,11 @@ router = APIRouter()
@RunTime
@router.post("similar_match")
@router.post("/similar_match")
def similar_match(request_item: SimilarMatchMItem):
try:
if request_item.result_number <= 0:
raise KeyError("result number can't be less than 0")
raise KeyError("results number can't be less than 0")
service = SimilarMatch(request_item)
search_response = service.match_features()
response_data = []
@@ -36,4 +36,4 @@ def similar_match(request_item: SimilarMatchMItem):
return {"message": "ok", "data": response_data}
except KeyError as e:
logger.warning(str(e))
return {"message": "result number can't be less than 0", "data": []}
return {"message": "results number can't be less than 0", "data": []}

View File

@@ -1,7 +1,9 @@
import logging
import os
from typing import ClassVar
from dotenv import load_dotenv
from pydantic import BaseSettings
from pydantic_settings import BaseSettings
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
logging.info(f"BASE_DIR : {BASE_DIR}")
@@ -9,14 +11,14 @@ load_dotenv(os.path.join(BASE_DIR, '.env'))
class Settings(BaseSettings):
PROJECT_NAME = os.getenv('PROJECT_NAME', 'FASTAPI BASE')
SECRET_KEY = os.getenv('SECRET_KEY', '')
API_PREFIX = ''
BACKEND_CORS_ORIGINS = ['*']
DATABASE_URL = os.getenv('SQL_DATABASE_URL', '')
PROJECT_NAME: ClassVar[str] = 'FASTAPI BASE'
SECRET_KEY: str = ''
API_PREFIX: str = ''
BACKEND_CORS_ORIGINS: list[str] = ['*']
DATABASE_URL: str = ''
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
SECURITY_ALGORITHM = 'HS256'
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
SECURITY_ALGORITHM: str = 'HS256'
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
settings = Settings()

View File

@@ -1,15 +1,14 @@
# config.py
import os
# import platform
# if platform.system() == 'Linux':
# __import__('pysqlite3')
# import sys
# sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
import platform
if platform.system() == 'Linux':
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
import chromadb
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
# import tritonclient.grpc as grpcclient
# from minio import Minio
import tritonclient.grpc as grpcclient
from minio import Minio
# OpenAI settings
OPENAI_API_KEY = "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"
@@ -26,7 +25,7 @@ MINIO_SECURE = False
MINIO_ACCESS = "e8zc55mzDOh4IzRrZ9Oa"
MINIO_SECRET = "uHfqJ7UkwA1PTDGfnA44Hp9ux5YkZTkzZLjeOYhE"
MINIO_BUCKET = "test"
# MINIO_CLIENT = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
MINIO_CLIENT = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# Set environment variables
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
@@ -71,4 +70,4 @@ elif OP_SERVER == "aws":
# os.environ["OOTD_URL"] = "http://18.167.251.121:10001/ootd/ootd_dc"
os.environ["OOTD_URL"] = "https://muskox-many-bluegill.ngrok-free.app/ootd_dc"
# triton_client = grpcclient.InferenceServerClient(url=os.environ['GRPCCLIENT_URL'])
triton_client = grpcclient.InferenceServerClient(url=os.environ['GRPCCLIENT_URL'])

View File

@@ -0,0 +1,74 @@
from app.service.lookbooks.config.config import DOCUMENT_COLLECTION # 引入向量数据库
from typing import Optional
async def query_lookbooks_service(tag: Optional[str], year: Optional[str], n_results: int):
"""
查询向量数据库,支持按 tag 和 year 过滤
:param tag: 查询过滤的标签
:param year: 查询过滤的年份
:param n_results: 返回的结果数量
:return: 查询结果列表
"""
try:
# 选择一个主要过滤条件进行初步查询
primary_filter = {}
if tag:
primary_filter['tag'] = tag
elif year:
primary_filter['year'] = year
else:
primary_filter = None
# 如果没有任何条件,直接返回空列表
if not primary_filter:
return []
# 使用主过滤条件进行查询
query_result = DOCUMENT_COLLECTION.get(where=primary_filter)
# 打印 query_result 以检查数据结构是否符合预期
print(query_result) # 或者使用 logger 记录以调试
# 检查 query_result 的结构
if not isinstance(query_result, dict) or 'documents' not in query_result:
raise ValueError("Expected query_result to be a dict containing a 'documents' key")
documents = query_result['documents']
# 确保 documents 是一个列表
if not isinstance(documents, list):
raise ValueError("Expected 'documents' to be a list")
# 对文档进行进一步过滤
filtered_results = []
for item in documents:
# 检查每个 item 是否是字典,如果不是字典,可能是简单文本描述
if isinstance(item, dict):
metadata = item.get('metadata', {})
if (not year or metadata.get('year') == year) and (not tag or metadata.get('tag') == tag):
filtered_results.append(item)
elif isinstance(item, str):
# 如果 item 是字符串,假设它是描述文本,构造默认的 metadata
filtered_results.append({
"text": item,
"metadata": {}
})
else:
# 如果是其他数据类型,打印警告并跳过
print(f"Unexpected item type in documents: {type(item)}")
# 限制结果数量
limited_results = filtered_results[:n_results]
# 格式化结果
result_list = []
for item in limited_results:
result_list.append({
"description": item.get('text', ""),
"metadata": item.get('metadata', {})
})
return result_list
except Exception as e:
raise e # 抛出异常,让接口层处理

View File

@@ -45,6 +45,9 @@ def create_image_batch_requests(
# 预处理 prompt移除多余的空白和换行符
prompt = ' '.join(prompt.split())
# 创建目录(如果目录不存在)
os.makedirs(output_path, exist_ok=True)
completed_id = []
if os.path.exists(os.path.join(output_path, "image_description_results.jsonl")):
with open(os.path.join(output_path, "image_description_results.jsonl"), "r") as f:
@@ -141,7 +144,8 @@ async def process_lookbook_task(lookbook_list, tag, year):
try:
for look_book_path in tqdm.tqdm(lookbook_list):
lookbook_name = os.path.splitext(os.path.basename(look_book_path))[0]
output_dir = os.path.join("fashion_documents/lookbook/images", lookbook_name)
output_dir = os.path.join("fashion_documents", "lookbook", "images", lookbook_name)
# output_dir = os.path.join("fashion_documents/lookbook/images", lookbook_name)
os.makedirs(output_dir, exist_ok=True)
if not os.listdir(output_dir):
from unstructured.partition.pdf import partition_pdf
@@ -159,8 +163,10 @@ async def process_lookbook_task(lookbook_list, tag, year):
current_images = os.listdir(output_dir)
image_list.extend([os.path.join(output_dir, x) for x in current_images])
output_path = os.path.join("fashion_documents", "lookbook", "results")
# 1. 处理图片并生成批量请求
image_description_results_file = create_image_batch_requests(image_list, "fashion_documents/lookbook/results")
image_description_results_file = create_image_batch_requests(image_list, output_path)
# 2. 保存结果到向量数据库
if image_description_results_file:

View File

@@ -1,9 +1,9 @@
version: "3"
services:
trinity_mixi:
container_name: trinity_mixi
build: .
image: aidlabzcr/trinity_mixi:latest
# build: .
volumes:
- ./trinity_client_mixi:/trinity
- ./:/app
ports:
- "10100:4562"

Binary file not shown.