From 07e72c1ee15ff1fe4742c3be7040efd2e49a42d1 Mon Sep 17 00:00:00 2001 From: shahaibo <1023316923@qq.com> Date: Thu, 24 Oct 2024 15:59:36 +0800 Subject: [PATCH] =?UTF-8?q?TASK:lookbook=E4=B8=8A=E4=BC=A0=EF=BC=8C?= =?UTF-8?q?=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_lookbooks_query.py | 26 +++++++++ app/api/api_route.py | 4 +- app/service/lookbooks/query_service.py | 74 ++++++++++++++++++++++++++ app/service/lookbooks/service.py | 10 +++- 4 files changed, 111 insertions(+), 3 deletions(-) create mode 100644 app/api/api_lookbooks_query.py create mode 100644 app/service/lookbooks/query_service.py diff --git a/app/api/api_lookbooks_query.py b/app/api/api_lookbooks_query.py new file mode 100644 index 0000000..4cd14ad --- /dev/null +++ b/app/api/api_lookbooks_query.py @@ -0,0 +1,26 @@ +from fastapi import APIRouter, Query +from typing import Optional +from app.service.lookbooks.query_service import query_lookbooks_service # 引入业务逻辑 + +router = APIRouter() + +@router.get("/query-lookbooks") +async def query_lookbooks( + tag: Optional[str] = Query(None, description="Tag to filter lookbooks"), + year: Optional[str] = Query(None, description="Year to filter lookbooks"), + n_results: int = Query(10, description="Number of results to return") +): + """ + 查询向量数据库,支持按 tag 和 year 查询 + :param tag: 查询过滤的标签 + :param year: 查询过滤的年份 + :param n_results: 返回的结果数量 + :return: 查询结果 + """ + try: + # 调用业务逻辑层的查询服务 + result_list = await query_lookbooks_service(tag, year, n_results) + return {"status": "success", "data": result_list} + + except Exception as e: + return {"status": "error", "message": str(e)} diff --git a/app/api/api_route.py b/app/api/api_route.py index 1b7cf93..db2dbaf 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -1,6 +1,7 @@ from fastapi import APIRouter -from app.api import api_test, api_outfit_matcher, api_attribute, api_similar_match, api_process_lookbooks +from app.api import api_test, api_outfit_matcher, api_attribute, api_similar_match, api_process_lookbooks, \ + api_lookbooks_query router = APIRouter() @@ -9,3 +10,4 @@ router.include_router(api_outfit_matcher.router, tags=["outfit_matcher"], prefix router.include_router(api_attribute.router, tags=["attribute"], prefix="/api") router.include_router(api_similar_match.router, tags=["similar_match"], prefix="/api") router.include_router(api_process_lookbooks.router, tags=["process-lookbooks"], prefix="/api") +router.include_router(api_lookbooks_query.router, tags=["query-lookbooks"], prefix="/api") diff --git a/app/service/lookbooks/query_service.py b/app/service/lookbooks/query_service.py new file mode 100644 index 0000000..7c557b2 --- /dev/null +++ b/app/service/lookbooks/query_service.py @@ -0,0 +1,74 @@ +from app.service.lookbooks.config.config import DOCUMENT_COLLECTION # 引入向量数据库 +from typing import Optional + +async def query_lookbooks_service(tag: Optional[str], year: Optional[str], n_results: int): + """ + 查询向量数据库,支持按 tag 和 year 过滤 + :param tag: 查询过滤的标签 + :param year: 查询过滤的年份 + :param n_results: 返回的结果数量 + :return: 查询结果列表 + """ + try: + # 选择一个主要过滤条件进行初步查询 + primary_filter = {} + if tag: + primary_filter['tag'] = tag + elif year: + primary_filter['year'] = year + else: + primary_filter = None + + # 如果没有任何条件,直接返回空列表 + if not primary_filter: + return [] + + # 使用主过滤条件进行查询 + query_result = DOCUMENT_COLLECTION.get(where=primary_filter) + + # 打印 query_result 以检查数据结构是否符合预期 + print(query_result) # 或者使用 logger 记录以调试 + + # 检查 query_result 的结构 + if not isinstance(query_result, dict) or 'documents' not in query_result: + raise ValueError("Expected query_result to be a dict containing a 'documents' key") + + documents = query_result['documents'] + + # 确保 documents 是一个列表 + if not isinstance(documents, list): + raise ValueError("Expected 'documents' to be a list") + + # 对文档进行进一步过滤 + filtered_results = [] + for item in documents: + # 检查每个 item 是否是字典,如果不是字典,可能是简单文本描述 + if isinstance(item, dict): + metadata = item.get('metadata', {}) + if (not year or metadata.get('year') == year) and (not tag or metadata.get('tag') == tag): + filtered_results.append(item) + elif isinstance(item, str): + # 如果 item 是字符串,假设它是描述文本,构造默认的 metadata + filtered_results.append({ + "text": item, + "metadata": {} + }) + else: + # 如果是其他数据类型,打印警告并跳过 + print(f"Unexpected item type in documents: {type(item)}") + + # 限制结果数量 + limited_results = filtered_results[:n_results] + + # 格式化结果 + result_list = [] + for item in limited_results: + result_list.append({ + "description": item.get('text', ""), + "metadata": item.get('metadata', {}) + }) + + return result_list + + except Exception as e: + raise e # 抛出异常,让接口层处理 diff --git a/app/service/lookbooks/service.py b/app/service/lookbooks/service.py index 77fc518..e81652b 100644 --- a/app/service/lookbooks/service.py +++ b/app/service/lookbooks/service.py @@ -45,6 +45,9 @@ def create_image_batch_requests( # 预处理 prompt:移除多余的空白和换行符 prompt = ' '.join(prompt.split()) + # 创建目录(如果目录不存在) + os.makedirs(output_path, exist_ok=True) + completed_id = [] if os.path.exists(os.path.join(output_path, "image_description_results.jsonl")): with open(os.path.join(output_path, "image_description_results.jsonl"), "r") as f: @@ -141,7 +144,8 @@ async def process_lookbook_task(lookbook_list, tag, year): try: for look_book_path in tqdm.tqdm(lookbook_list): lookbook_name = os.path.splitext(os.path.basename(look_book_path))[0] - output_dir = os.path.join("fashion_documents/lookbook/images", lookbook_name) + output_dir = os.path.join("fashion_documents", "lookbook", "images", lookbook_name) + # output_dir = os.path.join("fashion_documents/lookbook/images", lookbook_name) os.makedirs(output_dir, exist_ok=True) if not os.listdir(output_dir): from unstructured.partition.pdf import partition_pdf @@ -159,8 +163,10 @@ async def process_lookbook_task(lookbook_list, tag, year): current_images = os.listdir(output_dir) image_list.extend([os.path.join(output_dir, x) for x in current_images]) + output_path = os.path.join("fashion_documents", "lookbook", "results") + # 1. 处理图片并生成批量请求 - image_description_results_file = create_image_batch_requests(image_list, "fashion_documents/lookbook/results") + image_description_results_file = create_image_batch_requests(image_list, output_path) # 2. 保存结果到向量数据库 if image_description_results_file: