Compare commits
10 Commits
edbce4ac16
...
5b227e3008
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b227e3008 | ||
|
|
07e72c1ee1 | ||
|
|
93c37e268a | ||
|
|
dd0bdb16b5 | ||
|
|
d9f35f9faa | ||
|
|
7da300b2c9 | ||
|
|
c8483fdc0c | ||
|
|
6e0942ca3c | ||
|
|
ed017fdf9d | ||
|
|
61af80541b |
@@ -1,4 +1,4 @@
|
|||||||
FROM python:3.9
|
FROM python:3.12
|
||||||
ENV TZ=Asia/Shanghai
|
ENV TZ=Asia/Shanghai
|
||||||
RUN apt update
|
RUN apt update
|
||||||
RUN apt install -y vim
|
RUN apt install -y vim
|
||||||
@@ -9,6 +9,7 @@ RUN apt install -y tesseract-ocr
|
|||||||
COPY ./requirements.txt /requirements.txt
|
COPY ./requirements.txt /requirements.txt
|
||||||
RUN pip install --upgrade pip
|
RUN pip install --upgrade pip
|
||||||
RUN pip install -r requirements.txt
|
RUN pip install -r requirements.txt
|
||||||
|
RUN pip install --upgrade grpcio
|
||||||
RUN pip install gunicorn
|
RUN pip install gunicorn
|
||||||
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
|
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||||
RUN pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
|
RUN pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
|
||||||
|
|||||||
@@ -16,11 +16,10 @@
|
|||||||
|
|
||||||
1. 安装依赖
|
1. 安装依赖
|
||||||
|
|
||||||
$ conda create -n trinity_client_mixi python=3.9 -y
|
$ conda create -n trinity_client_mixi python=3.12 -y
|
||||||
$ conda activate trinity_client_mixi
|
$ conda activate trinity_client_mixi
|
||||||
$ pip install -r requirements.txt
|
$ pip install -r requirements.txt
|
||||||
$ conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
|
$ conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
|
||||||
$ pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
|
|
||||||
|
|
||||||
|
|
||||||
2. 启动服务器
|
2. 启动服务器
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ logger = logging.getLogger()
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("/attribute")
|
||||||
def attribute(request_data: AttributeModel):
|
def attribute(request_data: AttributeModel):
|
||||||
logger.info(f"attribute requests is @@@@@@@@@@@:{request_data}")
|
logger.info(f"attribute requests is @@@@@@@@@@@:{request_data}")
|
||||||
service = AttributeRecognition()
|
service = AttributeRecognition()
|
||||||
|
|||||||
26
app/api/api_lookbooks_query.py
Normal file
26
app/api/api_lookbooks_query.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from fastapi import APIRouter, Query
|
||||||
|
from typing import Optional
|
||||||
|
from app.service.lookbooks.query_service import query_lookbooks_service # 引入业务逻辑
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/query-lookbooks")
|
||||||
|
async def query_lookbooks(
|
||||||
|
tag: Optional[str] = Query(None, description="Tag to filter lookbooks"),
|
||||||
|
year: Optional[str] = Query(None, description="Year to filter lookbooks"),
|
||||||
|
n_results: int = Query(10, description="Number of results to return")
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
查询向量数据库,支持按 tag 和 year 查询
|
||||||
|
:param tag: 查询过滤的标签
|
||||||
|
:param year: 查询过滤的年份
|
||||||
|
:param n_results: 返回的结果数量
|
||||||
|
:return: 查询结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 调用业务逻辑层的查询服务
|
||||||
|
result_list = await query_lookbooks_service(tag, year, n_results)
|
||||||
|
return {"status": "success", "data": result_list}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"status": "error", "message": str(e)}
|
||||||
@@ -16,7 +16,7 @@ logger = logging.getLogger()
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.post("outfit_matcher")
|
@router.post("/outfit_matcher")
|
||||||
def outfit_matcher(request_item: OutfitMatcher):
|
def outfit_matcher(request_item: OutfitMatcher):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
request_item = dict(request_item)
|
request_item = dict(request_item)
|
||||||
|
|||||||
@@ -1,18 +1,17 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import tqdm
|
|
||||||
from fastapi import UploadFile, File, APIRouter, BackgroundTasks, Form
|
from fastapi import UploadFile, File, APIRouter, BackgroundTasks, Form
|
||||||
from app.service.lookbooks.service import create_image_batch_requests, process_lookbook_task # 引入服务逻辑
|
|
||||||
|
from app.service.lookbooks.service import process_lookbook_task # 引入服务逻辑
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/process-lookbooks/")
|
@router.post("/process-lookbooks")
|
||||||
async def process_lookbooks(
|
async def process_lookbooks(
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
files: List[UploadFile] = File(...),
|
files: List[UploadFile] = File(...),
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from app.api import api_test, api_outfit_matcher, api_attribute, api_similar_match, api_process_lookbooks
|
from app.api import api_test, api_outfit_matcher, api_attribute, api_similar_match, api_process_lookbooks, \
|
||||||
|
api_lookbooks_query
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
router.include_router(api_test.router, tags=["test"], prefix="/test")
|
router.include_router(api_test.router, tags=["test"], prefix="/test")
|
||||||
router.include_router(api_outfit_matcher.router, tags=["outfit_matcher"], prefix="/api/outfit_matcher")
|
router.include_router(api_outfit_matcher.router, tags=["outfit_matcher"], prefix="/api")
|
||||||
router.include_router(api_attribute.router, tags=["attribute"], prefix="/api/attribute")
|
router.include_router(api_attribute.router, tags=["attribute"], prefix="/api")
|
||||||
router.include_router(api_similar_match.router, tags=["similar_match"], prefix="/api/similar_match")
|
router.include_router(api_similar_match.router, tags=["similar_match"], prefix="/api")
|
||||||
router.include_router(api_process_lookbooks.router, tags=["process_lookbooks"], prefix="/api/process_lookbooks")
|
router.include_router(api_process_lookbooks.router, tags=["process-lookbooks"], prefix="/api")
|
||||||
|
router.include_router(api_lookbooks_query.router, tags=["query-lookbooks"], prefix="/api")
|
||||||
|
|||||||
@@ -16,11 +16,11 @@ router = APIRouter()
|
|||||||
|
|
||||||
|
|
||||||
@RunTime
|
@RunTime
|
||||||
@router.post("similar_match")
|
@router.post("/similar_match")
|
||||||
def similar_match(request_item: SimilarMatchMItem):
|
def similar_match(request_item: SimilarMatchMItem):
|
||||||
try:
|
try:
|
||||||
if request_item.result_number <= 0:
|
if request_item.result_number <= 0:
|
||||||
raise KeyError("result number can't be less than 0")
|
raise KeyError("results number can't be less than 0")
|
||||||
service = SimilarMatch(request_item)
|
service = SimilarMatch(request_item)
|
||||||
search_response = service.match_features()
|
search_response = service.match_features()
|
||||||
response_data = []
|
response_data = []
|
||||||
@@ -36,4 +36,4 @@ def similar_match(request_item: SimilarMatchMItem):
|
|||||||
return {"message": "ok", "data": response_data}
|
return {"message": "ok", "data": response_data}
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
logger.warning(str(e))
|
logger.warning(str(e))
|
||||||
return {"message": "result number can't be less than 0", "data": []}
|
return {"message": "results number can't be less than 0", "data": []}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pydantic import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
|
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
|
||||||
logging.info(f"BASE_DIR : {BASE_DIR}")
|
logging.info(f"BASE_DIR : {BASE_DIR}")
|
||||||
@@ -9,14 +11,14 @@ load_dotenv(os.path.join(BASE_DIR, '.env'))
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
PROJECT_NAME = os.getenv('PROJECT_NAME', 'FASTAPI BASE')
|
PROJECT_NAME: ClassVar[str] = 'FASTAPI BASE'
|
||||||
SECRET_KEY = os.getenv('SECRET_KEY', '')
|
SECRET_KEY: str = ''
|
||||||
API_PREFIX = ''
|
API_PREFIX: str = ''
|
||||||
BACKEND_CORS_ORIGINS = ['*']
|
BACKEND_CORS_ORIGINS: list[str] = ['*']
|
||||||
DATABASE_URL = os.getenv('SQL_DATABASE_URL', '')
|
DATABASE_URL: str = ''
|
||||||
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
|
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
|
||||||
SECURITY_ALGORITHM = 'HS256'
|
SECURITY_ALGORITHM: str = 'HS256'
|
||||||
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
|
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
# config.py
|
|
||||||
import os
|
import os
|
||||||
# import platform
|
import platform
|
||||||
# if platform.system() == 'Linux':
|
if platform.system() == 'Linux':
|
||||||
# __import__('pysqlite3')
|
__import__('pysqlite3')
|
||||||
# import sys
|
import sys
|
||||||
# sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
|
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
|
||||||
import chromadb
|
import chromadb
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
from langchain_chroma import Chroma
|
from langchain_chroma import Chroma
|
||||||
# import tritonclient.grpc as grpcclient
|
import tritonclient.grpc as grpcclient
|
||||||
# from minio import Minio
|
from minio import Minio
|
||||||
|
|
||||||
# OpenAI settings
|
# OpenAI settings
|
||||||
OPENAI_API_KEY = "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"
|
OPENAI_API_KEY = "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"
|
||||||
@@ -26,7 +25,7 @@ MINIO_SECURE = False
|
|||||||
MINIO_ACCESS = "e8zc55mzDOh4IzRrZ9Oa"
|
MINIO_ACCESS = "e8zc55mzDOh4IzRrZ9Oa"
|
||||||
MINIO_SECRET = "uHfqJ7UkwA1PTDGfnA44Hp9ux5YkZTkzZLjeOYhE"
|
MINIO_SECRET = "uHfqJ7UkwA1PTDGfnA44Hp9ux5YkZTkzZLjeOYhE"
|
||||||
MINIO_BUCKET = "test"
|
MINIO_BUCKET = "test"
|
||||||
# MINIO_CLIENT = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
MINIO_CLIENT = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
|
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
|
||||||
@@ -71,4 +70,4 @@ elif OP_SERVER == "aws":
|
|||||||
# os.environ["OOTD_URL"] = "http://18.167.251.121:10001/ootd/ootd_dc"
|
# os.environ["OOTD_URL"] = "http://18.167.251.121:10001/ootd/ootd_dc"
|
||||||
os.environ["OOTD_URL"] = "https://muskox-many-bluegill.ngrok-free.app/ootd_dc"
|
os.environ["OOTD_URL"] = "https://muskox-many-bluegill.ngrok-free.app/ootd_dc"
|
||||||
|
|
||||||
# triton_client = grpcclient.InferenceServerClient(url=os.environ['GRPCCLIENT_URL'])
|
triton_client = grpcclient.InferenceServerClient(url=os.environ['GRPCCLIENT_URL'])
|
||||||
|
|||||||
74
app/service/lookbooks/query_service.py
Normal file
74
app/service/lookbooks/query_service.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from app.service.lookbooks.config.config import DOCUMENT_COLLECTION # 引入向量数据库
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
async def query_lookbooks_service(tag: Optional[str], year: Optional[str], n_results: int):
|
||||||
|
"""
|
||||||
|
查询向量数据库,支持按 tag 和 year 过滤
|
||||||
|
:param tag: 查询过滤的标签
|
||||||
|
:param year: 查询过滤的年份
|
||||||
|
:param n_results: 返回的结果数量
|
||||||
|
:return: 查询结果列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 选择一个主要过滤条件进行初步查询
|
||||||
|
primary_filter = {}
|
||||||
|
if tag:
|
||||||
|
primary_filter['tag'] = tag
|
||||||
|
elif year:
|
||||||
|
primary_filter['year'] = year
|
||||||
|
else:
|
||||||
|
primary_filter = None
|
||||||
|
|
||||||
|
# 如果没有任何条件,直接返回空列表
|
||||||
|
if not primary_filter:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 使用主过滤条件进行查询
|
||||||
|
query_result = DOCUMENT_COLLECTION.get(where=primary_filter)
|
||||||
|
|
||||||
|
# 打印 query_result 以检查数据结构是否符合预期
|
||||||
|
print(query_result) # 或者使用 logger 记录以调试
|
||||||
|
|
||||||
|
# 检查 query_result 的结构
|
||||||
|
if not isinstance(query_result, dict) or 'documents' not in query_result:
|
||||||
|
raise ValueError("Expected query_result to be a dict containing a 'documents' key")
|
||||||
|
|
||||||
|
documents = query_result['documents']
|
||||||
|
|
||||||
|
# 确保 documents 是一个列表
|
||||||
|
if not isinstance(documents, list):
|
||||||
|
raise ValueError("Expected 'documents' to be a list")
|
||||||
|
|
||||||
|
# 对文档进行进一步过滤
|
||||||
|
filtered_results = []
|
||||||
|
for item in documents:
|
||||||
|
# 检查每个 item 是否是字典,如果不是字典,可能是简单文本描述
|
||||||
|
if isinstance(item, dict):
|
||||||
|
metadata = item.get('metadata', {})
|
||||||
|
if (not year or metadata.get('year') == year) and (not tag or metadata.get('tag') == tag):
|
||||||
|
filtered_results.append(item)
|
||||||
|
elif isinstance(item, str):
|
||||||
|
# 如果 item 是字符串,假设它是描述文本,构造默认的 metadata
|
||||||
|
filtered_results.append({
|
||||||
|
"text": item,
|
||||||
|
"metadata": {}
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# 如果是其他数据类型,打印警告并跳过
|
||||||
|
print(f"Unexpected item type in documents: {type(item)}")
|
||||||
|
|
||||||
|
# 限制结果数量
|
||||||
|
limited_results = filtered_results[:n_results]
|
||||||
|
|
||||||
|
# 格式化结果
|
||||||
|
result_list = []
|
||||||
|
for item in limited_results:
|
||||||
|
result_list.append({
|
||||||
|
"description": item.get('text', ""),
|
||||||
|
"metadata": item.get('metadata', {})
|
||||||
|
})
|
||||||
|
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise e # 抛出异常,让接口层处理
|
||||||
@@ -45,6 +45,9 @@ def create_image_batch_requests(
|
|||||||
# 预处理 prompt:移除多余的空白和换行符
|
# 预处理 prompt:移除多余的空白和换行符
|
||||||
prompt = ' '.join(prompt.split())
|
prompt = ' '.join(prompt.split())
|
||||||
|
|
||||||
|
# 创建目录(如果目录不存在)
|
||||||
|
os.makedirs(output_path, exist_ok=True)
|
||||||
|
|
||||||
completed_id = []
|
completed_id = []
|
||||||
if os.path.exists(os.path.join(output_path, "image_description_results.jsonl")):
|
if os.path.exists(os.path.join(output_path, "image_description_results.jsonl")):
|
||||||
with open(os.path.join(output_path, "image_description_results.jsonl"), "r") as f:
|
with open(os.path.join(output_path, "image_description_results.jsonl"), "r") as f:
|
||||||
@@ -141,7 +144,8 @@ async def process_lookbook_task(lookbook_list, tag, year):
|
|||||||
try:
|
try:
|
||||||
for look_book_path in tqdm.tqdm(lookbook_list):
|
for look_book_path in tqdm.tqdm(lookbook_list):
|
||||||
lookbook_name = os.path.splitext(os.path.basename(look_book_path))[0]
|
lookbook_name = os.path.splitext(os.path.basename(look_book_path))[0]
|
||||||
output_dir = os.path.join("fashion_documents/lookbook/images", lookbook_name)
|
output_dir = os.path.join("fashion_documents", "lookbook", "images", lookbook_name)
|
||||||
|
# output_dir = os.path.join("fashion_documents/lookbook/images", lookbook_name)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
if not os.listdir(output_dir):
|
if not os.listdir(output_dir):
|
||||||
from unstructured.partition.pdf import partition_pdf
|
from unstructured.partition.pdf import partition_pdf
|
||||||
@@ -159,8 +163,10 @@ async def process_lookbook_task(lookbook_list, tag, year):
|
|||||||
current_images = os.listdir(output_dir)
|
current_images = os.listdir(output_dir)
|
||||||
image_list.extend([os.path.join(output_dir, x) for x in current_images])
|
image_list.extend([os.path.join(output_dir, x) for x in current_images])
|
||||||
|
|
||||||
|
output_path = os.path.join("fashion_documents", "lookbook", "results")
|
||||||
|
|
||||||
# 1. 处理图片并生成批量请求
|
# 1. 处理图片并生成批量请求
|
||||||
image_description_results_file = create_image_batch_requests(image_list, "fashion_documents/lookbook/results")
|
image_description_results_file = create_image_batch_requests(image_list, output_path)
|
||||||
|
|
||||||
# 2. 保存结果到向量数据库
|
# 2. 保存结果到向量数据库
|
||||||
if image_description_results_file:
|
if image_description_results_file:
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
version: "3"
|
version: "3"
|
||||||
services:
|
services:
|
||||||
trinity_mixi:
|
trinity_mixi:
|
||||||
container_name: trinity_mixi
|
image: aidlabzcr/trinity_mixi:latest
|
||||||
build: .
|
# build: .
|
||||||
volumes:
|
volumes:
|
||||||
- ./trinity_client_mixi:/trinity
|
- ./:/app
|
||||||
ports:
|
ports:
|
||||||
- "10100:4562"
|
- "10100:4562"
|
||||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
Reference in New Issue
Block a user