feat(新功能): 项目信息提取/生成接口
fix(修复bug): docs(文档变更): refactor(重构): test(增加测试):
This commit is contained in:
@@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException
|
|||||||
|
|
||||||
from app.schemas.project_info_extraction import ProjectInfoExtractionModel
|
from app.schemas.project_info_extraction import ProjectInfoExtractionModel
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
from app.service.project_info_extraction.service_generate_brand_info import ProjectInfoExtraction
|
from app.service.project_info_extraction.service import ProjectInfoExtraction
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|||||||
61
app/service/project_info_extraction/service.py
Normal file
61
app/service/project_info_extraction/service.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
|
||||||
|
from langchain_community.chat_models import ChatTongyi
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
|
||||||
|
from app.schemas.project_info_extraction import ProjectInfoExtractionModel
|
||||||
|
|
||||||
|
style = ['NEW_CHINESE', 'COUNTRY_STYLE', 'FUTURISM', 'MINIMALISM', 'LOLITA', 'Y2K', 'BUSINESS', 'MERLAD',
|
||||||
|
'OUTDOOR_FUNCTIONAL', 'ROCK', 'DOPAMINE', 'GOTHIC', 'POST_APOCALYPTIC', 'ROMANTIC', 'WABI_SABI']
|
||||||
|
position = ['Overall', 'Tops', 'Bottoms', 'Outwear', 'Blouse', 'Dress', 'Trousers', 'Skirt']
|
||||||
|
gender = ['Female', 'Male']
|
||||||
|
age_group = ['Adult', 'Child']
|
||||||
|
process = ['SERIES_DESIGN', 'SINGLE_DESIGN']
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectInfoExtraction:
|
||||||
|
def __init__(self, request_data):
|
||||||
|
# llm generate brand info init
|
||||||
|
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
||||||
|
|
||||||
|
self.response_schemas = [
|
||||||
|
ResponseSchema(name="project_name", description="项目的名称."),
|
||||||
|
ResponseSchema(name="process", description="项目的类型,单品还是系列."),
|
||||||
|
ResponseSchema(name="ageGroup", description="项目设计服装的受众对象."),
|
||||||
|
ResponseSchema(name="gender", description="项目设计服装的受众性别."),
|
||||||
|
ResponseSchema(name="position", description="项目单品设计的部位."),
|
||||||
|
ResponseSchema(name="style", description="项目的设计风格.")
|
||||||
|
]
|
||||||
|
self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas)
|
||||||
|
self.format_instructions = self.output_parser.get_format_instructions()
|
||||||
|
self.prompt = PromptTemplate(
|
||||||
|
template="你是一个时装品牌的设计师助理。根据用户输入提取出"
|
||||||
|
"[project_name] : 项目的名称,"
|
||||||
|
f"[process] : 项目的类型,从{process}选择."
|
||||||
|
f"[ageGroup] : 服装的受众,从{age_group}选择."
|
||||||
|
f"[gender] : 服装的适用性别,从{gender}选择"
|
||||||
|
f"[position] : single_design的部位,如果[process]是SINGLE_DESIGN,从{position}中选择,如果[process]是SERIES_DESIGN,这项为空"
|
||||||
|
f"[style] : 设计的风格,从{style}中选择"
|
||||||
|
".\n{format_instructions}\n{question}",
|
||||||
|
input_variables=["question"],
|
||||||
|
partial_variables={"format_instructions": self.format_instructions}
|
||||||
|
)
|
||||||
|
self._input = self.prompt.format_prompt(question=request_data.prompt)
|
||||||
|
|
||||||
|
self.result_data = {}
|
||||||
|
|
||||||
|
def get_result(self):
|
||||||
|
self.llm_extraction_project_info()
|
||||||
|
return self.result_data
|
||||||
|
|
||||||
|
def llm_extraction_project_info(self):
|
||||||
|
output = self.model(self._input.to_messages())
|
||||||
|
project_info = self.output_parser.parse(output.content)
|
||||||
|
self.result_data = project_info
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
request_data = ProjectInfoExtractionModel(
|
||||||
|
prompt="海边派对主题的衬衫设计"
|
||||||
|
)
|
||||||
|
service = ProjectInfoExtraction(request_data)
|
||||||
|
print(service.get_result())
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
|
|
||||||
from langchain_community.chat_models import ChatTongyi
|
|
||||||
from langchain_core.prompts import PromptTemplate
|
|
||||||
|
|
||||||
from app.schemas.project_info_extraction import ProjectInfoExtractionModel
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectInfoExtraction:
|
|
||||||
def __init__(self, request_data):
|
|
||||||
# llm generate brand info init
|
|
||||||
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
|
||||||
|
|
||||||
self.response_schemas = [
|
|
||||||
ResponseSchema(name="project_name", description="project name."),
|
|
||||||
ResponseSchema(name="role", description="The target role of the project."),
|
|
||||||
ResponseSchema(name="gender", description="The gender targeted by the project."),
|
|
||||||
ResponseSchema(name="style", description="Project style.")
|
|
||||||
]
|
|
||||||
self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas)
|
|
||||||
self.format_instructions = self.output_parser.get_format_instructions()
|
|
||||||
self.prompt = PromptTemplate(
|
|
||||||
template="你是一个时装品牌的设计师助理。根据用户输入提取出project_name,role ,gender ,style ."
|
|
||||||
"gender部分请用以下:menswear,womenswear,childrenwear,如果全部都适用即all."
|
|
||||||
"如果没有以上内容,需要你根据用户输入随意发挥.\n{format_instructions}\n{question}",
|
|
||||||
input_variables=["question"],
|
|
||||||
partial_variables={"format_instructions": self.format_instructions}
|
|
||||||
)
|
|
||||||
self._input = self.prompt.format_prompt(question=request_data.prompt)
|
|
||||||
|
|
||||||
self.result_data = {}
|
|
||||||
|
|
||||||
def get_result(self):
|
|
||||||
self.llm_extraction_project_info()
|
|
||||||
return self.result_data
|
|
||||||
|
|
||||||
def llm_extraction_project_info(self):
|
|
||||||
output = self.model(self._input.to_messages())
|
|
||||||
project_info = self.output_parser.parse(output.content)
|
|
||||||
self.result_data = project_info
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
request_data = ProjectInfoExtractionModel(
|
|
||||||
prompt="海边派对主题的系列设计"
|
|
||||||
)
|
|
||||||
service = ProjectInfoExtraction(request_data)
|
|
||||||
print(service.get_result())
|
|
||||||
Reference in New Issue
Block a user