diff --git a/app/api/api_extraction_project_info.py b/app/api/api_extraction_project_info.py index 51eb473..ad55552 100644 --- a/app/api/api_extraction_project_info.py +++ b/app/api/api_extraction_project_info.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException from app.schemas.project_info_extraction import ProjectInfoExtractionModel from app.schemas.response_template import ResponseModel -from app.service.project_info_extraction.service_generate_brand_info import ProjectInfoExtraction +from app.service.project_info_extraction.service import ProjectInfoExtraction router = APIRouter() logger = logging.getLogger() diff --git a/app/service/project_info_extraction/service.py b/app/service/project_info_extraction/service.py new file mode 100644 index 0000000..40d59ba --- /dev/null +++ b/app/service/project_info_extraction/service.py @@ -0,0 +1,61 @@ +from langchain.output_parsers import ResponseSchema, StructuredOutputParser +from langchain_community.chat_models import ChatTongyi +from langchain_core.prompts import PromptTemplate + +from app.schemas.project_info_extraction import ProjectInfoExtractionModel + +style = ['NEW_CHINESE', 'COUNTRY_STYLE', 'FUTURISM', 'MINIMALISM', 'LOLITA', 'Y2K', 'BUSINESS', 'MERLAD', + 'OUTDOOR_FUNCTIONAL', 'ROCK', 'DOPAMINE', 'GOTHIC', 'POST_APOCALYPTIC', 'ROMANTIC', 'WABI_SABI'] +position = ['Overall', 'Tops', 'Bottoms', 'Outwear', 'Blouse', 'Dress', 'Trousers', 'Skirt'] +gender = ['Female', 'Male'] +age_group = ['Adult', 'Child'] +process = ['SERIES_DESIGN', 'SINGLE_DESIGN'] + + +class ProjectInfoExtraction: + def __init__(self, request_data): + # llm generate brand info init + self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab") + + self.response_schemas = [ + ResponseSchema(name="project_name", description="项目的名称."), + ResponseSchema(name="process", description="项目的类型,单品还是系列."), + ResponseSchema(name="ageGroup", description="项目设计服装的受众对象."), + ResponseSchema(name="gender", description="项目设计服装的受众性别."), + ResponseSchema(name="position", description="项目单品设计的部位."), + ResponseSchema(name="style", description="项目的设计风格.") + ] + self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas) + self.format_instructions = self.output_parser.get_format_instructions() + self.prompt = PromptTemplate( + template="你是一个时装品牌的设计师助理。根据用户输入提取出" + "[project_name] : 项目的名称," + f"[process] : 项目的类型,从{process}选择." + f"[ageGroup] : 服装的受众,从{age_group}选择." + f"[gender] : 服装的适用性别,从{gender}选择" + f"[position] : single_design的部位,如果[process]是SINGLE_DESIGN,从{position}中选择,如果[process]是SERIES_DESIGN,这项为空" + f"[style] : 设计的风格,从{style}中选择" + ".\n{format_instructions}\n{question}", + input_variables=["question"], + partial_variables={"format_instructions": self.format_instructions} + ) + self._input = self.prompt.format_prompt(question=request_data.prompt) + + self.result_data = {} + + def get_result(self): + self.llm_extraction_project_info() + return self.result_data + + def llm_extraction_project_info(self): + output = self.model(self._input.to_messages()) + project_info = self.output_parser.parse(output.content) + self.result_data = project_info + + +if __name__ == '__main__': + request_data = ProjectInfoExtractionModel( + prompt="海边派对主题的衬衫设计" + ) + service = ProjectInfoExtraction(request_data) + print(service.get_result()) diff --git a/app/service/project_info_extraction/service_generate_brand_info.py b/app/service/project_info_extraction/service_generate_brand_info.py deleted file mode 100644 index 8ee7bcd..0000000 --- a/app/service/project_info_extraction/service_generate_brand_info.py +++ /dev/null @@ -1,47 +0,0 @@ -from langchain.output_parsers import ResponseSchema, StructuredOutputParser -from langchain_community.chat_models import ChatTongyi -from langchain_core.prompts import PromptTemplate - -from app.schemas.project_info_extraction import ProjectInfoExtractionModel - - -class ProjectInfoExtraction: - def __init__(self, request_data): - # llm generate brand info init - self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab") - - self.response_schemas = [ - ResponseSchema(name="project_name", description="project name."), - ResponseSchema(name="role", description="The target role of the project."), - ResponseSchema(name="gender", description="The gender targeted by the project."), - ResponseSchema(name="style", description="Project style.") - ] - self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas) - self.format_instructions = self.output_parser.get_format_instructions() - self.prompt = PromptTemplate( - template="你是一个时装品牌的设计师助理。根据用户输入提取出project_name,role ,gender ,style ." - "gender部分请用以下:menswear,womenswear,childrenwear,如果全部都适用即all." - "如果没有以上内容,需要你根据用户输入随意发挥.\n{format_instructions}\n{question}", - input_variables=["question"], - partial_variables={"format_instructions": self.format_instructions} - ) - self._input = self.prompt.format_prompt(question=request_data.prompt) - - self.result_data = {} - - def get_result(self): - self.llm_extraction_project_info() - return self.result_data - - def llm_extraction_project_info(self): - output = self.model(self._input.to_messages()) - project_info = self.output_parser.parse(output.content) - self.result_data = project_info - - -if __name__ == '__main__': - request_data = ProjectInfoExtractionModel( - prompt="海边派对主题的系列设计" - ) - service = ProjectInfoExtraction(request_data) - print(service.get_result())