Compare commits
10 Commits
edbce4ac16
...
5b227e3008
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b227e3008 | ||
|
|
07e72c1ee1 | ||
|
|
93c37e268a | ||
|
|
dd0bdb16b5 | ||
|
|
d9f35f9faa | ||
|
|
7da300b2c9 | ||
|
|
c8483fdc0c | ||
|
|
6e0942ca3c | ||
|
|
ed017fdf9d | ||
|
|
61af80541b |
@@ -1,4 +1,4 @@
|
||||
FROM python:3.9
|
||||
FROM python:3.12
|
||||
ENV TZ=Asia/Shanghai
|
||||
RUN apt update
|
||||
RUN apt install -y vim
|
||||
@@ -9,6 +9,7 @@ RUN apt install -y tesseract-ocr
|
||||
COPY ./requirements.txt /requirements.txt
|
||||
RUN pip install --upgrade pip
|
||||
RUN pip install -r requirements.txt
|
||||
RUN pip install --upgrade grpcio
|
||||
RUN pip install gunicorn
|
||||
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
RUN pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
|
||||
|
||||
@@ -16,11 +16,10 @@
|
||||
|
||||
1. 安装依赖
|
||||
|
||||
$ conda create -n trinity_client_mixi python=3.9 -y
|
||||
$ conda create -n trinity_client_mixi python=3.12 -y
|
||||
$ conda activate trinity_client_mixi
|
||||
$ pip install -r requirements.txt
|
||||
$ conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
|
||||
$ pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
|
||||
|
||||
|
||||
2. 启动服务器
|
||||
|
||||
@@ -11,7 +11,7 @@ logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("")
|
||||
@router.post("/attribute")
|
||||
def attribute(request_data: AttributeModel):
|
||||
logger.info(f"attribute requests is @@@@@@@@@@@:{request_data}")
|
||||
service = AttributeRecognition()
|
||||
|
||||
26
app/api/api_lookbooks_query.py
Normal file
26
app/api/api_lookbooks_query.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from fastapi import APIRouter, Query
|
||||
from typing import Optional
|
||||
from app.service.lookbooks.query_service import query_lookbooks_service # 引入业务逻辑
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/query-lookbooks")
|
||||
async def query_lookbooks(
|
||||
tag: Optional[str] = Query(None, description="Tag to filter lookbooks"),
|
||||
year: Optional[str] = Query(None, description="Year to filter lookbooks"),
|
||||
n_results: int = Query(10, description="Number of results to return")
|
||||
):
|
||||
"""
|
||||
查询向量数据库,支持按 tag 和 year 查询
|
||||
:param tag: 查询过滤的标签
|
||||
:param year: 查询过滤的年份
|
||||
:param n_results: 返回的结果数量
|
||||
:return: 查询结果
|
||||
"""
|
||||
try:
|
||||
# 调用业务逻辑层的查询服务
|
||||
result_list = await query_lookbooks_service(tag, year, n_results)
|
||||
return {"status": "success", "data": result_list}
|
||||
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
@@ -16,7 +16,7 @@ logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("outfit_matcher")
|
||||
@router.post("/outfit_matcher")
|
||||
def outfit_matcher(request_item: OutfitMatcher):
|
||||
start_time = time.time()
|
||||
request_item = dict(request_item)
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import List
|
||||
|
||||
import aiofiles
|
||||
import tqdm
|
||||
from fastapi import UploadFile, File, APIRouter, BackgroundTasks, Form
|
||||
from app.service.lookbooks.service import create_image_batch_requests, process_lookbook_task # 引入服务逻辑
|
||||
|
||||
from app.service.lookbooks.service import process_lookbook_task # 引入服务逻辑
|
||||
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/process-lookbooks/")
|
||||
@router.post("/process-lookbooks")
|
||||
async def process_lookbooks(
|
||||
background_tasks: BackgroundTasks,
|
||||
files: List[UploadFile] = File(...),
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api import api_test, api_outfit_matcher, api_attribute, api_similar_match, api_process_lookbooks
|
||||
from app.api import api_test, api_outfit_matcher, api_attribute, api_similar_match, api_process_lookbooks, \
|
||||
api_lookbooks_query
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
router.include_router(api_test.router, tags=["test"], prefix="/test")
|
||||
router.include_router(api_outfit_matcher.router, tags=["outfit_matcher"], prefix="/api/outfit_matcher")
|
||||
router.include_router(api_attribute.router, tags=["attribute"], prefix="/api/attribute")
|
||||
router.include_router(api_similar_match.router, tags=["similar_match"], prefix="/api/similar_match")
|
||||
router.include_router(api_process_lookbooks.router, tags=["process_lookbooks"], prefix="/api/process_lookbooks")
|
||||
router.include_router(api_outfit_matcher.router, tags=["outfit_matcher"], prefix="/api")
|
||||
router.include_router(api_attribute.router, tags=["attribute"], prefix="/api")
|
||||
router.include_router(api_similar_match.router, tags=["similar_match"], prefix="/api")
|
||||
router.include_router(api_process_lookbooks.router, tags=["process-lookbooks"], prefix="/api")
|
||||
router.include_router(api_lookbooks_query.router, tags=["query-lookbooks"], prefix="/api")
|
||||
|
||||
@@ -16,11 +16,11 @@ router = APIRouter()
|
||||
|
||||
|
||||
@RunTime
|
||||
@router.post("similar_match")
|
||||
@router.post("/similar_match")
|
||||
def similar_match(request_item: SimilarMatchMItem):
|
||||
try:
|
||||
if request_item.result_number <= 0:
|
||||
raise KeyError("result number can't be less than 0")
|
||||
raise KeyError("results number can't be less than 0")
|
||||
service = SimilarMatch(request_item)
|
||||
search_response = service.match_features()
|
||||
response_data = []
|
||||
@@ -36,4 +36,4 @@ def similar_match(request_item: SimilarMatchMItem):
|
||||
return {"message": "ok", "data": response_data}
|
||||
except KeyError as e:
|
||||
logger.warning(str(e))
|
||||
return {"message": "result number can't be less than 0", "data": []}
|
||||
return {"message": "results number can't be less than 0", "data": []}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import ClassVar
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseSettings
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
|
||||
logging.info(f"BASE_DIR : {BASE_DIR}")
|
||||
@@ -9,14 +11,14 @@ load_dotenv(os.path.join(BASE_DIR, '.env'))
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME = os.getenv('PROJECT_NAME', 'FASTAPI BASE')
|
||||
SECRET_KEY = os.getenv('SECRET_KEY', '')
|
||||
API_PREFIX = ''
|
||||
BACKEND_CORS_ORIGINS = ['*']
|
||||
DATABASE_URL = os.getenv('SQL_DATABASE_URL', '')
|
||||
PROJECT_NAME: ClassVar[str] = 'FASTAPI BASE'
|
||||
SECRET_KEY: str = ''
|
||||
API_PREFIX: str = ''
|
||||
BACKEND_CORS_ORIGINS: list[str] = ['*']
|
||||
DATABASE_URL: str = ''
|
||||
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
|
||||
SECURITY_ALGORITHM = 'HS256'
|
||||
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
|
||||
SECURITY_ALGORITHM: str = 'HS256'
|
||||
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
# config.py
|
||||
import os
|
||||
# import platform
|
||||
# if platform.system() == 'Linux':
|
||||
# __import__('pysqlite3')
|
||||
# import sys
|
||||
# sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
|
||||
import platform
|
||||
if platform.system() == 'Linux':
|
||||
__import__('pysqlite3')
|
||||
import sys
|
||||
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
|
||||
import chromadb
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_chroma import Chroma
|
||||
# import tritonclient.grpc as grpcclient
|
||||
# from minio import Minio
|
||||
import tritonclient.grpc as grpcclient
|
||||
from minio import Minio
|
||||
|
||||
# OpenAI settings
|
||||
OPENAI_API_KEY = "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"
|
||||
@@ -26,7 +25,7 @@ MINIO_SECURE = False
|
||||
MINIO_ACCESS = "e8zc55mzDOh4IzRrZ9Oa"
|
||||
MINIO_SECRET = "uHfqJ7UkwA1PTDGfnA44Hp9ux5YkZTkzZLjeOYhE"
|
||||
MINIO_BUCKET = "test"
|
||||
# MINIO_CLIENT = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
MINIO_CLIENT = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
# Set environment variables
|
||||
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
|
||||
@@ -71,4 +70,4 @@ elif OP_SERVER == "aws":
|
||||
# os.environ["OOTD_URL"] = "http://18.167.251.121:10001/ootd/ootd_dc"
|
||||
os.environ["OOTD_URL"] = "https://muskox-many-bluegill.ngrok-free.app/ootd_dc"
|
||||
|
||||
# triton_client = grpcclient.InferenceServerClient(url=os.environ['GRPCCLIENT_URL'])
|
||||
triton_client = grpcclient.InferenceServerClient(url=os.environ['GRPCCLIENT_URL'])
|
||||
|
||||
74
app/service/lookbooks/query_service.py
Normal file
74
app/service/lookbooks/query_service.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from app.service.lookbooks.config.config import DOCUMENT_COLLECTION # 引入向量数据库
|
||||
from typing import Optional
|
||||
|
||||
async def query_lookbooks_service(tag: Optional[str], year: Optional[str], n_results: int):
|
||||
"""
|
||||
查询向量数据库,支持按 tag 和 year 过滤
|
||||
:param tag: 查询过滤的标签
|
||||
:param year: 查询过滤的年份
|
||||
:param n_results: 返回的结果数量
|
||||
:return: 查询结果列表
|
||||
"""
|
||||
try:
|
||||
# 选择一个主要过滤条件进行初步查询
|
||||
primary_filter = {}
|
||||
if tag:
|
||||
primary_filter['tag'] = tag
|
||||
elif year:
|
||||
primary_filter['year'] = year
|
||||
else:
|
||||
primary_filter = None
|
||||
|
||||
# 如果没有任何条件,直接返回空列表
|
||||
if not primary_filter:
|
||||
return []
|
||||
|
||||
# 使用主过滤条件进行查询
|
||||
query_result = DOCUMENT_COLLECTION.get(where=primary_filter)
|
||||
|
||||
# 打印 query_result 以检查数据结构是否符合预期
|
||||
print(query_result) # 或者使用 logger 记录以调试
|
||||
|
||||
# 检查 query_result 的结构
|
||||
if not isinstance(query_result, dict) or 'documents' not in query_result:
|
||||
raise ValueError("Expected query_result to be a dict containing a 'documents' key")
|
||||
|
||||
documents = query_result['documents']
|
||||
|
||||
# 确保 documents 是一个列表
|
||||
if not isinstance(documents, list):
|
||||
raise ValueError("Expected 'documents' to be a list")
|
||||
|
||||
# 对文档进行进一步过滤
|
||||
filtered_results = []
|
||||
for item in documents:
|
||||
# 检查每个 item 是否是字典,如果不是字典,可能是简单文本描述
|
||||
if isinstance(item, dict):
|
||||
metadata = item.get('metadata', {})
|
||||
if (not year or metadata.get('year') == year) and (not tag or metadata.get('tag') == tag):
|
||||
filtered_results.append(item)
|
||||
elif isinstance(item, str):
|
||||
# 如果 item 是字符串,假设它是描述文本,构造默认的 metadata
|
||||
filtered_results.append({
|
||||
"text": item,
|
||||
"metadata": {}
|
||||
})
|
||||
else:
|
||||
# 如果是其他数据类型,打印警告并跳过
|
||||
print(f"Unexpected item type in documents: {type(item)}")
|
||||
|
||||
# 限制结果数量
|
||||
limited_results = filtered_results[:n_results]
|
||||
|
||||
# 格式化结果
|
||||
result_list = []
|
||||
for item in limited_results:
|
||||
result_list.append({
|
||||
"description": item.get('text', ""),
|
||||
"metadata": item.get('metadata', {})
|
||||
})
|
||||
|
||||
return result_list
|
||||
|
||||
except Exception as e:
|
||||
raise e # 抛出异常,让接口层处理
|
||||
@@ -45,6 +45,9 @@ def create_image_batch_requests(
|
||||
# 预处理 prompt:移除多余的空白和换行符
|
||||
prompt = ' '.join(prompt.split())
|
||||
|
||||
# 创建目录(如果目录不存在)
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
completed_id = []
|
||||
if os.path.exists(os.path.join(output_path, "image_description_results.jsonl")):
|
||||
with open(os.path.join(output_path, "image_description_results.jsonl"), "r") as f:
|
||||
@@ -141,7 +144,8 @@ async def process_lookbook_task(lookbook_list, tag, year):
|
||||
try:
|
||||
for look_book_path in tqdm.tqdm(lookbook_list):
|
||||
lookbook_name = os.path.splitext(os.path.basename(look_book_path))[0]
|
||||
output_dir = os.path.join("fashion_documents/lookbook/images", lookbook_name)
|
||||
output_dir = os.path.join("fashion_documents", "lookbook", "images", lookbook_name)
|
||||
# output_dir = os.path.join("fashion_documents/lookbook/images", lookbook_name)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
if not os.listdir(output_dir):
|
||||
from unstructured.partition.pdf import partition_pdf
|
||||
@@ -159,8 +163,10 @@ async def process_lookbook_task(lookbook_list, tag, year):
|
||||
current_images = os.listdir(output_dir)
|
||||
image_list.extend([os.path.join(output_dir, x) for x in current_images])
|
||||
|
||||
output_path = os.path.join("fashion_documents", "lookbook", "results")
|
||||
|
||||
# 1. 处理图片并生成批量请求
|
||||
image_description_results_file = create_image_batch_requests(image_list, "fashion_documents/lookbook/results")
|
||||
image_description_results_file = create_image_batch_requests(image_list, output_path)
|
||||
|
||||
# 2. 保存结果到向量数据库
|
||||
if image_description_results_file:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
version: "3"
|
||||
services:
|
||||
trinity_mixi:
|
||||
container_name: trinity_mixi
|
||||
build: .
|
||||
image: aidlabzcr/trinity_mixi:latest
|
||||
# build: .
|
||||
volumes:
|
||||
- ./trinity_client_mixi:/trinity
|
||||
- ./:/app
|
||||
ports:
|
||||
- "10100:4562"
|
||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
Reference in New Issue
Block a user