From 6cb32d11a8c4a1f333bad099486b9e119cf834f9 Mon Sep 17 00:00:00 2001 From: zchengrong <124802516+zchengrong@users.noreply.github.com> Date: Thu, 15 May 2025 14:49:33 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20=20=E9=A1=B9=E7=9B=AE=E4=BF=A1=E6=81=AF=E6=8F=90?= =?UTF-8?q?=E5=8F=96/=E7=94=9F=E6=88=90=E6=8E=A5=E5=8F=A3=20fix=EF=BC=88?= =?UTF-8?q?=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs=EF=BC=88=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88?= =?UTF-8?q?=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B?= =?UTF-8?q?=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_extraction_project_info.py | 33 +++++++++++++ app/api/api_route.py | 5 +- app/schemas/project_info_extraction.py | 5 ++ .../service_generate_brand_info.py | 47 +++++++++++++++++++ 4 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 app/api/api_extraction_project_info.py create mode 100644 app/schemas/project_info_extraction.py create mode 100644 app/service/project_info_extraction/service_generate_brand_info.py diff --git a/app/api/api_extraction_project_info.py b/app/api/api_extraction_project_info.py new file mode 100644 index 0000000..51eb473 --- /dev/null +++ b/app/api/api_extraction_project_info.py @@ -0,0 +1,33 @@ +import logging + +from fastapi import APIRouter, HTTPException + +from app.schemas.project_info_extraction import ProjectInfoExtractionModel +from app.schemas.response_template import ResponseModel +from app.service.project_info_extraction.service_generate_brand_info import ProjectInfoExtraction + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/extraction_project_info") +def extraction_project_info(request_data: ProjectInfoExtractionModel): + """ + 通过prompt 提取project_name,role ,gender ,style。 + 创建一个具有以下参数的请求体: + - **prompt**: + + 示例参数: + { + "prompt": "海边派对主题的系列设计" + } + """ + try: + logger.info(f"extraction_project_info request item is : @@@@@@:{request_data}") + service = ProjectInfoExtraction(request_data) + data = service.get_result() + logger.info(f"extraction_project_info response @@@@@@:{data}") + except Exception as e: + logger.warning(f"extraction_project_info Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_route.py b/app/api/api_route.py index b82c942..d85cbce 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -5,16 +5,16 @@ from app.api import api_attribute_retrieve, api_query_image from app.api import api_brand_dna from app.api import api_brighten from app.api import api_chat_robot +from app.api import api_clothing_seg from app.api import api_design from app.api import api_design_pre_processing +from app.api import api_extraction_project_info from app.api import api_generate_image from app.api import api_image2sketch from app.api import api_mannequins_edit from app.api import api_pose_transform from app.api import api_prompt_generation -from app.api import api_clothing_seg from app.api import api_super_resolution -from app.api import api_recommendation from app.api import api_test router = APIRouter() @@ -36,3 +36,4 @@ router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], router.include_router(api_agent_generate_image.router, tags=['api_agent_generate_image'], prefix="/api") router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api") router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api") +router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api") diff --git a/app/schemas/project_info_extraction.py b/app/schemas/project_info_extraction.py new file mode 100644 index 0000000..90def8b --- /dev/null +++ b/app/schemas/project_info_extraction.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class ProjectInfoExtractionModel(BaseModel): + prompt: str diff --git a/app/service/project_info_extraction/service_generate_brand_info.py b/app/service/project_info_extraction/service_generate_brand_info.py new file mode 100644 index 0000000..8ee7bcd --- /dev/null +++ b/app/service/project_info_extraction/service_generate_brand_info.py @@ -0,0 +1,47 @@ +from langchain.output_parsers import ResponseSchema, StructuredOutputParser +from langchain_community.chat_models import ChatTongyi +from langchain_core.prompts import PromptTemplate + +from app.schemas.project_info_extraction import ProjectInfoExtractionModel + + +class ProjectInfoExtraction: + def __init__(self, request_data): + # llm generate brand info init + self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab") + + self.response_schemas = [ + ResponseSchema(name="project_name", description="project name."), + ResponseSchema(name="role", description="The target role of the project."), + ResponseSchema(name="gender", description="The gender targeted by the project."), + ResponseSchema(name="style", description="Project style.") + ] + self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas) + self.format_instructions = self.output_parser.get_format_instructions() + self.prompt = PromptTemplate( + template="你是一个时装品牌的设计师助理。根据用户输入提取出project_name,role ,gender ,style ." + "gender部分请用以下:menswear,womenswear,childrenwear,如果全部都适用即all." + "如果没有以上内容,需要你根据用户输入随意发挥.\n{format_instructions}\n{question}", + input_variables=["question"], + partial_variables={"format_instructions": self.format_instructions} + ) + self._input = self.prompt.format_prompt(question=request_data.prompt) + + self.result_data = {} + + def get_result(self): + self.llm_extraction_project_info() + return self.result_data + + def llm_extraction_project_info(self): + output = self.model(self._input.to_messages()) + project_info = self.output_parser.parse(output.content) + self.result_data = project_info + + +if __name__ == '__main__': + request_data = ProjectInfoExtractionModel( + prompt="海边派对主题的系列设计" + ) + service = ProjectInfoExtraction(request_data) + print(service.get_result())