from typing import Annotated, Required, TypedDict

from langchain.agents import create_agent
from langchain_core.messages import AnyMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages

from app.service.fashion_agent.graph_node.design_graph.graph import build_design_graph
from app.service.fashion_agent.graph_node.explore_graph.tools import explore_tool
from app.service.fashion_agent.graph_node.logo_graph.graph import build_logo_graph
from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool
from app.service.fashion_agent.graph_node.print_graph.graph import build_print_graph
from app.service.fashion_agent.graph_node.print_graph.tools import generate_print_tool
from app.service.fashion_agent.graph_node.sketch_graph.graph import build_sketch_graph
from app.service.fashion_agent.graph_node.sketch_graph.tools import generate_sketch_tool
from app.service.fashion_agent.graph_node.trending_graph.trending_graph import build_trending_graph
from app.service.fashion_agent.graph_node.explore_graph.graph import build_explore_graph
from app.service.fashion_agent.init_llm import build_llm

# Sub-graphs are built once at import time; every graph returned by
# build_main_graph reuses these same objects as its direct-route nodes.
print_graph = build_print_graph()
logo_graph = build_logo_graph()
sketch_graph = build_sketch_graph()
design_graph = build_design_graph()
trending_graph = build_trending_graph()
explore_graph = build_explore_graph()


class MainState(TypedDict, total=False):
    """State shared by the main graph and the LLM agent.

    Only ``messages`` is required; every other key is optional. The values
    assigned to the fields below are NOT inserted into state dicts by
    TypedDict -- they only record the intended fallback, so optional keys
    must be read with ``state.get(key, default)``.
    """

    # Conversation messages; updates are merged with add_messages.
    messages: Required[Annotated[list[AnyMessage], add_messages]]
    # Uploaded images (presumably URLs or paths -- TODO confirm).
    input_images: list[str] = []
    # Module routing flags, checked by route_node.
    call_design: bool = False
    call_print: bool = False
    call_logo: bool = False
    call_sketch: bool = False
    call_trending: bool = False
    call_explore: bool = False
    # Parameters for the design module.
    design_request_data: dict = {}
    # Per-module requirement flags.
    print_need_prompt_generation: bool = False
    sketch_need_prompt_generation: bool = False
    # Parameters shared across modules.
    role: str = ""
    gender: str = ""
    style: str = ""
    # Results of the print module.
    print_img_urls: list[str] = []


# Tools available to the free-form LLM agent node.
tools = [explore_tool, generate_logo_tool, generate_print_tool, generate_sketch_tool]

# Direct routes in priority order: (state flag, node name, sub-graph).
# Single source of truth for routing, node registration and edges.
_DIRECT_ROUTES = (
    ("call_print", "direct_print", print_graph),
    ("call_logo", "direct_logo", logo_graph),
    ("call_sketch", "direct_sketch", sketch_graph),
    ("call_design", "direct_design", design_graph),
    ("call_trending", "direct_trending", trending_graph),
    ("call_explore", "direct_explore", explore_graph),
)


def route_node(state: MainState) -> str:
    """Pick the node to run first, based on the routing flags in ``state``.

    Flags are checked in ``_DIRECT_ROUTES`` order and the first truthy one
    wins; if none is set, the request goes to the ``"llm_agent"`` node.
    """
    for flag, node_name, _ in _DIRECT_ROUTES:
        if state.get(flag):
            return node_name
    return "llm_agent"


async def build_main_graph(enable_thinking: bool = False, checkpointer=None):
    """Build and compile the main graph.

    START is routed by ``route_node`` to either one direct sub-graph node or
    the tool-calling LLM agent; every node then goes straight to END.

    Args:
        enable_thinking: Forwarded to ``build_llm``.
        checkpointer: Forwarded to ``StateGraph.compile``.

    Returns:
        The compiled graph.
    """
    llm = build_llm(enable_thinking=enable_thinking)
    chat_agent = create_agent(
        model=llm,
        tools=tools,
        state_schema=MainState,
        system_prompt="你是一个专业的服装设计助手。根据用户需求,调用合适的工具完成任务.",
    )

    workflow = StateGraph(MainState)

    # Nodes: the LLM agent plus one node per direct route.
    workflow.add_node("llm_agent", chat_agent)
    for _, node_name, subgraph in _DIRECT_ROUTES:
        workflow.add_node(node_name, subgraph)

    # route_node returns node names directly, so the path map is identity.
    path_map = {"llm_agent": "llm_agent"}
    path_map.update({node_name: node_name for _, node_name, _ in _DIRECT_ROUTES})
    workflow.add_conditional_edges(START, route_node, path_map)

    # Every path ends after a single node.
    for node_name in path_map:
        workflow.add_edge(node_name, END)

    return workflow.compile(checkpointer=checkpointer)