Compare commits

...

10 Commits

Author SHA1 Message Date
zhouchengrong
5b227e3008 feat 新增 process lookbooks 接口
fix
2024-10-24 17:35:46 +08:00
shahaibo
07e72c1ee1 TASK:lookbook上传,查询 2024-10-24 15:59:36 +08:00
zhouchengrong
93c37e268a feat 新增 process lookbooks 接口
fix
2024-10-23 10:07:42 +08:00
zhouchengrong
dd0bdb16b5 feat 新增 process lookbooks 接口
fix
2024-10-22 15:43:32 +08:00
zhouchengrong
d9f35f9faa feat 新增 process lookbooks 接口
fix
2024-10-22 15:34:33 +08:00
zhouchengrong
7da300b2c9 feat 新增 process lookbooks 接口
fix
2024-10-22 15:34:10 +08:00
zhouchengrong
c8483fdc0c feat 新增 process lookbooks 接口
fix
2024-10-22 15:20:03 +08:00
zhouchengrong
6e0942ca3c feat 新增 process lookbooks 接口
fix
2024-10-22 15:11:49 +08:00
zhouchengrong
ed017fdf9d feat 新增 process lookbooks 接口
fix
2024-10-22 15:11:08 +08:00
shahaibo
61af80541b TASK:lookbook上传 2024-10-22 12:11:04 +08:00
14 changed files with 148 additions and 40 deletions

View File

@@ -1,4 +1,4 @@
FROM python:3.9 FROM python:3.12
ENV TZ=Asia/Shanghai ENV TZ=Asia/Shanghai
RUN apt update RUN apt update
RUN apt install -y vim RUN apt install -y vim
@@ -9,6 +9,7 @@ RUN apt install -y tesseract-ocr
COPY ./requirements.txt /requirements.txt COPY ./requirements.txt /requirements.txt
RUN pip install --upgrade pip RUN pip install --upgrade pip
RUN pip install -r requirements.txt RUN pip install -r requirements.txt
RUN pip install --upgrade grpcio
RUN pip install gunicorn RUN pip install gunicorn
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
RUN pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html RUN pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html

View File

@@ -16,11 +16,10 @@
1. 安装依赖 1. 安装依赖
$ conda create -n trinity_client_mixi python=3.9 -y $ conda create -n trinity_client_mixi python=3.12 -y
$ conda activate trinity_client_mixi $ conda activate trinity_client_mixi
$ pip install -r requirements.txt $ pip install -r requirements.txt
$ conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y $ conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
$ pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
2. 启动服务器 2. 启动服务器

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger()
router = APIRouter() router = APIRouter()
@router.post("") @router.post("/attribute")
def attribute(request_data: AttributeModel): def attribute(request_data: AttributeModel):
logger.info(f"attribute requests is @@@@@@@@@@@:{request_data}") logger.info(f"attribute requests is @@@@@@@@@@@:{request_data}")
service = AttributeRecognition() service = AttributeRecognition()

View File

@@ -0,0 +1,26 @@
from fastapi import APIRouter, Query
from typing import Optional
from app.service.lookbooks.query_service import query_lookbooks_service # 引入业务逻辑
router = APIRouter()
@router.get("/query-lookbooks")
async def query_lookbooks(
tag: Optional[str] = Query(None, description="Tag to filter lookbooks"),
year: Optional[str] = Query(None, description="Year to filter lookbooks"),
n_results: int = Query(10, description="Number of results to return")
):
"""
查询向量数据库,支持按 tag 和 year 查询
:param tag: 查询过滤的标签
:param year: 查询过滤的年份
:param n_results: 返回的结果数量
:return: 查询结果
"""
try:
# 调用业务逻辑层的查询服务
result_list = await query_lookbooks_service(tag, year, n_results)
return {"status": "success", "data": result_list}
except Exception as e:
return {"status": "error", "message": str(e)}

View File

@@ -16,7 +16,7 @@ logger = logging.getLogger()
router = APIRouter() router = APIRouter()
@router.post("outfit_matcher") @router.post("/outfit_matcher")
def outfit_matcher(request_item: OutfitMatcher): def outfit_matcher(request_item: OutfitMatcher):
start_time = time.time() start_time = time.time()
request_item = dict(request_item) request_item = dict(request_item)

View File

@@ -1,18 +1,17 @@
import logging import logging
import os import os
import shutil
from typing import List from typing import List
import aiofiles import aiofiles
import tqdm
from fastapi import UploadFile, File, APIRouter, BackgroundTasks, Form from fastapi import UploadFile, File, APIRouter, BackgroundTasks, Form
from app.service.lookbooks.service import create_image_batch_requests, process_lookbook_task # 引入服务逻辑
from app.service.lookbooks.service import process_lookbook_task # 引入服务逻辑
logger = logging.getLogger() logger = logging.getLogger()
router = APIRouter() router = APIRouter()
@router.post("/process-lookbooks/") @router.post("/process-lookbooks")
async def process_lookbooks( async def process_lookbooks(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
files: List[UploadFile] = File(...), files: List[UploadFile] = File(...),

View File

@@ -1,11 +1,13 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.api import api_test, api_outfit_matcher, api_attribute, api_similar_match, api_process_lookbooks from app.api import api_test, api_outfit_matcher, api_attribute, api_similar_match, api_process_lookbooks, \
api_lookbooks_query
router = APIRouter() router = APIRouter()
router.include_router(api_test.router, tags=["test"], prefix="/test") router.include_router(api_test.router, tags=["test"], prefix="/test")
router.include_router(api_outfit_matcher.router, tags=["outfit_matcher"], prefix="/api/outfit_matcher") router.include_router(api_outfit_matcher.router, tags=["outfit_matcher"], prefix="/api")
router.include_router(api_attribute.router, tags=["attribute"], prefix="/api/attribute") router.include_router(api_attribute.router, tags=["attribute"], prefix="/api")
router.include_router(api_similar_match.router, tags=["similar_match"], prefix="/api/similar_match") router.include_router(api_similar_match.router, tags=["similar_match"], prefix="/api")
router.include_router(api_process_lookbooks.router, tags=["process_lookbooks"], prefix="/api/process_lookbooks") router.include_router(api_process_lookbooks.router, tags=["process-lookbooks"], prefix="/api")
router.include_router(api_lookbooks_query.router, tags=["query-lookbooks"], prefix="/api")

View File

@@ -16,11 +16,11 @@ router = APIRouter()
@RunTime @RunTime
@router.post("similar_match") @router.post("/similar_match")
def similar_match(request_item: SimilarMatchMItem): def similar_match(request_item: SimilarMatchMItem):
try: try:
if request_item.result_number <= 0: if request_item.result_number <= 0:
raise KeyError("result number can't be less than 0") raise KeyError("results number can't be less than 0")
service = SimilarMatch(request_item) service = SimilarMatch(request_item)
search_response = service.match_features() search_response = service.match_features()
response_data = [] response_data = []
@@ -36,4 +36,4 @@ def similar_match(request_item: SimilarMatchMItem):
return {"message": "ok", "data": response_data} return {"message": "ok", "data": response_data}
except KeyError as e: except KeyError as e:
logger.warning(str(e)) logger.warning(str(e))
return {"message": "result number can't be less than 0", "data": []} return {"message": "results number can't be less than 0", "data": []}

View File

@@ -1,7 +1,9 @@
import logging import logging
import os import os
from typing import ClassVar
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic import BaseSettings from pydantic_settings import BaseSettings
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')) BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
logging.info(f"BASE_DIR : {BASE_DIR}") logging.info(f"BASE_DIR : {BASE_DIR}")
@@ -9,14 +11,14 @@ load_dotenv(os.path.join(BASE_DIR, '.env'))
class Settings(BaseSettings): class Settings(BaseSettings):
PROJECT_NAME = os.getenv('PROJECT_NAME', 'FASTAPI BASE') PROJECT_NAME: ClassVar[str] = 'FASTAPI BASE'
SECRET_KEY = os.getenv('SECRET_KEY', '') SECRET_KEY: str = ''
API_PREFIX = '' API_PREFIX: str = ''
BACKEND_CORS_ORIGINS = ['*'] BACKEND_CORS_ORIGINS: list[str] = ['*']
DATABASE_URL = os.getenv('SQL_DATABASE_URL', '') DATABASE_URL: str = ''
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
SECURITY_ALGORITHM = 'HS256' SECURITY_ALGORITHM: str = 'HS256'
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
settings = Settings() settings = Settings()

View File

@@ -1,15 +1,14 @@
# config.py
import os import os
# import platform import platform
# if platform.system() == 'Linux': if platform.system() == 'Linux':
# __import__('pysqlite3') __import__('pysqlite3')
# import sys import sys
# sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
import chromadb import chromadb
from langchain_openai import OpenAIEmbeddings from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma from langchain_chroma import Chroma
# import tritonclient.grpc as grpcclient import tritonclient.grpc as grpcclient
# from minio import Minio from minio import Minio
# OpenAI settings # OpenAI settings
OPENAI_API_KEY = "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" OPENAI_API_KEY = "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"
@@ -26,7 +25,7 @@ MINIO_SECURE = False
MINIO_ACCESS = "e8zc55mzDOh4IzRrZ9Oa" MINIO_ACCESS = "e8zc55mzDOh4IzRrZ9Oa"
MINIO_SECRET = "uHfqJ7UkwA1PTDGfnA44Hp9ux5YkZTkzZLjeOYhE" MINIO_SECRET = "uHfqJ7UkwA1PTDGfnA44Hp9ux5YkZTkzZLjeOYhE"
MINIO_BUCKET = "test" MINIO_BUCKET = "test"
# MINIO_CLIENT = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) MINIO_CLIENT = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# Set environment variables # Set environment variables
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
@@ -71,4 +70,4 @@ elif OP_SERVER == "aws":
# os.environ["OOTD_URL"] = "http://18.167.251.121:10001/ootd/ootd_dc" # os.environ["OOTD_URL"] = "http://18.167.251.121:10001/ootd/ootd_dc"
os.environ["OOTD_URL"] = "https://muskox-many-bluegill.ngrok-free.app/ootd_dc" os.environ["OOTD_URL"] = "https://muskox-many-bluegill.ngrok-free.app/ootd_dc"
# triton_client = grpcclient.InferenceServerClient(url=os.environ['GRPCCLIENT_URL']) triton_client = grpcclient.InferenceServerClient(url=os.environ['GRPCCLIENT_URL'])

View File

@@ -0,0 +1,74 @@
from app.service.lookbooks.config.config import DOCUMENT_COLLECTION # 引入向量数据库
from typing import Optional
async def query_lookbooks_service(tag: Optional[str], year: Optional[str], n_results: int):
"""
查询向量数据库,支持按 tag 和 year 过滤
:param tag: 查询过滤的标签
:param year: 查询过滤的年份
:param n_results: 返回的结果数量
:return: 查询结果列表
"""
try:
# 选择一个主要过滤条件进行初步查询
primary_filter = {}
if tag:
primary_filter['tag'] = tag
elif year:
primary_filter['year'] = year
else:
primary_filter = None
# 如果没有任何条件,直接返回空列表
if not primary_filter:
return []
# 使用主过滤条件进行查询
query_result = DOCUMENT_COLLECTION.get(where=primary_filter)
# 打印 query_result 以检查数据结构是否符合预期
print(query_result) # 或者使用 logger 记录以调试
# 检查 query_result 的结构
if not isinstance(query_result, dict) or 'documents' not in query_result:
raise ValueError("Expected query_result to be a dict containing a 'documents' key")
documents = query_result['documents']
# 确保 documents 是一个列表
if not isinstance(documents, list):
raise ValueError("Expected 'documents' to be a list")
# 对文档进行进一步过滤
filtered_results = []
for item in documents:
# 检查每个 item 是否是字典,如果不是字典,可能是简单文本描述
if isinstance(item, dict):
metadata = item.get('metadata', {})
if (not year or metadata.get('year') == year) and (not tag or metadata.get('tag') == tag):
filtered_results.append(item)
elif isinstance(item, str):
# 如果 item 是字符串,假设它是描述文本,构造默认的 metadata
filtered_results.append({
"text": item,
"metadata": {}
})
else:
# 如果是其他数据类型,打印警告并跳过
print(f"Unexpected item type in documents: {type(item)}")
# 限制结果数量
limited_results = filtered_results[:n_results]
# 格式化结果
result_list = []
for item in limited_results:
result_list.append({
"description": item.get('text', ""),
"metadata": item.get('metadata', {})
})
return result_list
except Exception as e:
raise e # 抛出异常,让接口层处理

View File

@@ -45,6 +45,9 @@ def create_image_batch_requests(
# 预处理 prompt移除多余的空白和换行符 # 预处理 prompt移除多余的空白和换行符
prompt = ' '.join(prompt.split()) prompt = ' '.join(prompt.split())
# 创建目录(如果目录不存在)
os.makedirs(output_path, exist_ok=True)
completed_id = [] completed_id = []
if os.path.exists(os.path.join(output_path, "image_description_results.jsonl")): if os.path.exists(os.path.join(output_path, "image_description_results.jsonl")):
with open(os.path.join(output_path, "image_description_results.jsonl"), "r") as f: with open(os.path.join(output_path, "image_description_results.jsonl"), "r") as f:
@@ -141,7 +144,8 @@ async def process_lookbook_task(lookbook_list, tag, year):
try: try:
for look_book_path in tqdm.tqdm(lookbook_list): for look_book_path in tqdm.tqdm(lookbook_list):
lookbook_name = os.path.splitext(os.path.basename(look_book_path))[0] lookbook_name = os.path.splitext(os.path.basename(look_book_path))[0]
output_dir = os.path.join("fashion_documents/lookbook/images", lookbook_name) output_dir = os.path.join("fashion_documents", "lookbook", "images", lookbook_name)
# output_dir = os.path.join("fashion_documents/lookbook/images", lookbook_name)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
if not os.listdir(output_dir): if not os.listdir(output_dir):
from unstructured.partition.pdf import partition_pdf from unstructured.partition.pdf import partition_pdf
@@ -159,8 +163,10 @@ async def process_lookbook_task(lookbook_list, tag, year):
current_images = os.listdir(output_dir) current_images = os.listdir(output_dir)
image_list.extend([os.path.join(output_dir, x) for x in current_images]) image_list.extend([os.path.join(output_dir, x) for x in current_images])
output_path = os.path.join("fashion_documents", "lookbook", "results")
# 1. 处理图片并生成批量请求 # 1. 处理图片并生成批量请求
image_description_results_file = create_image_batch_requests(image_list, "fashion_documents/lookbook/results") image_description_results_file = create_image_batch_requests(image_list, output_path)
# 2. 保存结果到向量数据库 # 2. 保存结果到向量数据库
if image_description_results_file: if image_description_results_file:

View File

@@ -1,9 +1,9 @@
version: "3" version: "3"
services: services:
trinity_mixi: trinity_mixi:
container_name: trinity_mixi image: aidlabzcr/trinity_mixi:latest
build: . # build: .
volumes: volumes:
- ./trinity_client_mixi:/trinity - ./:/app
ports: ports:
- "10100:4562" - "10100:4562"

Binary file not shown.