2026-06-15 14:48:17 +08:00
|
|
|
from typing import Annotated, Required, TypedDict
|
|
|
|
|
from langchain.agents import create_agent
|
2026-06-17 11:56:53 +08:00
|
|
|
from langchain_core.messages import AnyMessage
|
2026-06-15 14:48:17 +08:00
|
|
|
from langgraph.graph import END, START, StateGraph
|
|
|
|
|
from langgraph.graph.message import add_messages
|
2026-06-17 11:56:53 +08:00
|
|
|
|
2026-06-15 14:48:17 +08:00
|
|
|
from app.service.fashion_agent.graph_node.design_graph.graph import build_design_graph
|
2026-06-17 11:56:53 +08:00
|
|
|
from app.service.fashion_agent.graph_node.explore_graph.tools import explore_tool
|
2026-06-15 14:48:17 +08:00
|
|
|
from app.service.fashion_agent.graph_node.logo_graph.graph import build_logo_graph
|
2026-06-15 17:10:04 +08:00
|
|
|
from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool
|
2026-06-15 14:48:17 +08:00
|
|
|
from app.service.fashion_agent.graph_node.print_graph.graph import build_print_graph
|
|
|
|
|
from app.service.fashion_agent.graph_node.print_graph.tools import generate_print_tool
|
|
|
|
|
from app.service.fashion_agent.graph_node.sketch_graph.graph import build_sketch_graph
|
|
|
|
|
from app.service.fashion_agent.graph_node.sketch_graph.tools import generate_sketch_tool
|
|
|
|
|
from app.service.fashion_agent.graph_node.trending_graph.trending_graph import build_trending_graph
|
2026-06-17 11:56:53 +08:00
|
|
|
from app.service.fashion_agent.graph_node.explore_graph.graph import build_explore_graph
|
2026-06-15 14:48:17 +08:00
|
|
|
from app.service.fashion_agent.init_llm import build_llm
|
|
|
|
|
|
|
|
|
|
# Build every module sub-graph once, at import time, in a fixed order.
# build_main_graph() mounts these shared instances as its "direct_*" nodes.
(
    print_graph,
    logo_graph,
    sketch_graph,
    design_graph,
    trending_graph,
    explore_graph,
) = (
    build_print_graph(),
    build_logo_graph(),
    build_sketch_graph(),
    build_design_graph(),
    build_trending_graph(),
    build_explore_graph(),
)
|
2026-06-15 14:48:17 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainState(TypedDict, total=False):
    """State shared by the top-level fashion-agent graph and its nodes.

    Only ``messages`` is required. Every other key is optional and may be
    missing from a given state dict, so readers should use ``state.get(...)``
    (as the router does). A TypedDict cannot carry runtime default values,
    so "unset" simply means "key absent".
    """

    # Conversation history; updates are merged with the ``add_messages`` reducer.
    messages: Required[Annotated[list[AnyMessage], add_messages]]

    # Uploaded images (presumably URLs or paths -- TODO confirm with the caller).
    input_images: list[str]

    # Module routing flags: when one is truthy the router sends the request
    # straight to that module's sub-graph instead of the LLM agent.
    call_design: bool
    call_print: bool
    call_logo: bool
    call_sketch: bool
    call_trending: bool
    call_explore: bool

    # Request parameters for the design module (schema not visible here).
    design_request_data: dict

    # Per-module flags; presumably tell the print/sketch sub-graphs whether a
    # prompt must be generated first -- confirm in those graphs.
    print_need_prompt_generation: bool
    sketch_need_prompt_generation: bool

    # Parameters shared across modules.
    role: str
    gender: str
    style: str

    # Result of the print module.
    print_img_urls: list[str]
|
|
|
|
|
|
|
|
|
|
|
2026-06-17 11:56:53 +08:00
|
|
|
# Tools handed to the general chat agent created in build_main_graph();
# the list order is preserved as-is.
tools = [
    explore_tool,
    generate_logo_tool,
    generate_print_tool,
    generate_sketch_tool,
]
|
2026-06-15 14:48:17 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def route_node(state: MainState) -> str:
    """Choose the node to run first, based on the ``call_*`` flags in *state*.

    Flags are checked in a fixed priority order (print, logo, sketch, design,
    trending, explore), and the first truthy one decides the route. When no
    flag is set, the request goes to the general ``"llm_agent"`` node.
    """
    # (state key, node name) pairs in priority order; earlier entries win.
    priority = (
        ("call_print", "direct_print"),
        ("call_logo", "direct_logo"),
        ("call_sketch", "direct_sketch"),
        ("call_design", "direct_design"),
        ("call_trending", "direct_trending"),
        ("call_explore", "direct_explore"),
    )
    # The generator stops at the first truthy flag, so later keys are not read.
    return next(
        (target for flag, target in priority if state.get(flag)),
        "llm_agent",
    )
|
|
|
|
|
|
|
|
|
|
|
2026-06-17 11:56:53 +08:00
|
|
|
async def build_main_graph(enable_thinking: bool = False, checkpointer=None):
    """Build and compile the top-level fashion-agent graph.

    Execution starts at START and is dispatched by ``route_node``. If one of
    the ``call_*`` flags in the state is set, the matching pre-built module
    sub-graph runs directly. Otherwise a tool-calling chat agent handles the
    request. Every node then finishes at END.

    Nothing is awaited inside; the function is declared ``async`` (presumably
    for caller compatibility), so callers must ``await`` it.

    Args:
        enable_thinking: Forwarded to ``build_llm`` for the chat agent's model.
        checkpointer: Forwarded unchanged to ``StateGraph.compile``.

    Returns:
        The compiled graph produced by ``workflow.compile``.
    """
    llm = build_llm(enable_thinking=enable_thinking)
    chat_agent = create_agent(
        model=llm,
        tools=tools,
        state_schema=MainState,
        system_prompt="你是一个专业的服装设计助手。根据用户需求,调用合适的工具完成任务.",
    )

    # Single source of truth for node names: each key must match a string that
    # route_node can return. Insertion order = node registration order.
    nodes = {
        "llm_agent": chat_agent,
        "direct_print": print_graph,
        "direct_logo": logo_graph,
        "direct_sketch": sketch_graph,
        "direct_design": design_graph,
        "direct_trending": trending_graph,
        "direct_explore": explore_graph,
    }

    workflow = StateGraph(MainState)

    # Register nodes.
    for node_name, runnable in nodes.items():
        workflow.add_node(node_name, runnable)

    # Conditional branch from START; the router's return value names the node.
    workflow.add_conditional_edges(
        START,
        route_node,
        {node_name: node_name for node_name in nodes},
    )

    # Every path terminates after a single node.
    for node_name in nodes:
        workflow.add_edge(node_name, END)

    return workflow.compile(checkpointer=checkpointer)
|