Files
AiDA_Python/app/service/fashion_agent/main_agent.py

145 lines
4.9 KiB
Python
Raw Normal View History

2026-06-15 14:48:17 +08:00
import sys
from pathlib import Path
from typing import Annotated, Required, TypedDict
from deepagents import CompiledSubAgent, create_deep_agent
from langchain.agents import create_agent
from langchain.tools import tool
from langchain_core.messages import AnyMessage, HumanMessage
from langchain_qwq import ChatQwen
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from app.service.fashion_agent.graph_node.design_graph.graph import build_design_graph
from app.service.fashion_agent.graph_node.design_graph.tools import design_tool
from app.service.fashion_agent.graph_node.explorer_graph.tools import explor_tool
from app.service.fashion_agent.graph_node.logo_graph.graph import build_logo_graph
from app.service.fashion_agent.graph_node.node_tools.generate_logo import generate_logo_tool
from app.service.fashion_agent.graph_node.print_graph.graph import build_print_graph
from app.service.fashion_agent.graph_node.print_graph.tools import generate_print_tool
from app.service.fashion_agent.graph_node.sketch_graph.graph import build_sketch_graph
from app.service.fashion_agent.graph_node.sketch_graph.tools import generate_sketch_tool
from app.service.fashion_agent.graph_node.trending_graph.trending_graph import build_trending_graph
from app.service.fashion_agent.graph_node.explorer_graph.graph import build_explorer_graph
from app.service.fashion_agent.init_llm import build_llm
# Build each module subgraph once, at import time. build_main_graph mounts
# these objects as its "direct_*" nodes, so they are shared by every graph it
# compiles. The builders are invoked in the same order as before.
print_graph, logo_graph, sketch_graph, design_graph, trending_graph, explorer_graph = (
    build_print_graph(),
    build_logo_graph(),
    build_sketch_graph(),
    build_design_graph(),
    build_trending_graph(),
    build_explorer_graph(),
)
class MainState(TypedDict, total=False):
    """Shared state schema for the main fashion-agent graph.

    Declared with ``total=False`` so every key except ``messages`` is
    optional. This matches how the state is used: callers pass only the keys
    they need (the ``__main__`` smoke test sends just ``messages``,
    ``call_sketch`` and ``sketch_need_prompt_generation``), and ``route_node``
    reads the flags with ``state.get(...)``, so an absent flag behaves as
    False.

    Note: ``TypedDict`` has no per-field defaults. Assignments in a TypedDict
    class body are never applied to state instances, so the intended defaults
    are recorded in the comments below instead of as class attributes.
    """

    # Conversation messages, combined via LangGraph's add_messages reducer.
    messages: Required[Annotated[list[AnyMessage], add_messages]]

    # Module dispatch flags (intended default: absent / False).
    # route_node checks them in a fixed priority order and routes to the
    # first truthy one's "direct_*" node.
    call_design: bool
    call_print: bool
    call_logo: bool
    call_sketch: bool
    call_trending: bool
    call_explor: bool

    # Parameters for the design module (intended default: empty dict).
    design_request_data: dict

    # Per-module "needs prompt generation" flags (intended default: False).
    print_need_prompt_generation: bool
    sketch_need_prompt_generation: bool

    # Common parameters shared across modules (intended default: "").
    role: str
    gender: str
    style: str

    # Results of the print module (intended default: empty list).
    print_img_urls: list[str]
# Tools made available to the fallback LLM agent (the "llm_agent" node built
# in build_main_graph).
# NOTE(review): design_tool is imported in this module but not listed here —
# confirm whether that omission is intentional.
tools = [
    explor_tool,
    generate_logo_tool,
    generate_print_tool,
    generate_sketch_tool,
]
def route_node(state: MainState) -> str:
    """Choose the node that the START branch should dispatch to.

    The ``call_*`` flags are inspected in a fixed priority order (print,
    logo, sketch, design, trending, explor). The first truthy flag wins, and
    its matching ``direct_*`` node name is returned. If no flag is truthy, or
    none is present, the request goes to ``llm_agent``.
    """
    priority = (
        ("call_print", "direct_print"),
        ("call_logo", "direct_logo"),
        ("call_sketch", "direct_sketch"),
        ("call_design", "direct_design"),
        ("call_trending", "direct_trending"),
        ("call_explor", "direct_explor"),
    )
    # The generator is lazy, so flags after the first truthy one are never read.
    return next(
        (target for flag, target in priority if state.get(flag)),
        "llm_agent",
    )
def build_main_graph(enable_thinking: bool = False):
    """Build and compile the top-level fashion-agent graph.

    At START, ``route_node`` inspects the ``call_*`` flags in the incoming
    state. It either dispatches directly to one of the module-level subgraphs
    (skipping the LLM) or falls back to a tool-calling chat agent. Every
    branch then goes straight to END.

    Args:
        enable_thinking: Forwarded to ``build_llm`` when creating the model
            used by the fallback ``llm_agent`` node.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    llm = build_llm(enable_thinking=enable_thinking)
    chat_agent = create_agent(
        model=llm,
        tools=tools,
        state_schema=MainState,
        system_prompt="你是一个专业的服装设计助手。根据用户需求,调用合适的工具完成任务.",
    )

    # Single source of truth for node name -> runnable. Node registration,
    # the routing path map and the edges to END all derive from this mapping,
    # so the three can no longer drift apart. Insertion order matches the
    # original registration order.
    nodes = {
        "llm_agent": chat_agent,
        "direct_print": print_graph,
        "direct_logo": logo_graph,
        "direct_sketch": sketch_graph,
        "direct_design": design_graph,
        "direct_trending": trending_graph,
        "direct_explor": explorer_graph,
    }

    workflow = StateGraph(MainState)
    for name, runnable in nodes.items():
        workflow.add_node(name, runnable)

    # Conditional branch at START: route_node returns one of the node names
    # above, so the path map is the identity mapping over those names.
    workflow.add_conditional_edges(START, route_node, {name: name for name in nodes})

    # Every path terminates the run.
    for name in nodes:
        workflow.add_edge(name, END)

    return workflow.compile()
# Module-level compiled graph, built once at import time. enable_thinking is
# spelled out explicitly; False is also build_main_graph's default.
agent = build_main_graph(enable_thinking=False)
if __name__ == "__main__":
    import asyncio

    async def test_direct():
        """Smoke test: invoke the sketch subgraph directly, bypassing the LLM.

        Setting ``call_sketch`` makes ``route_node`` dispatch to the
        ``direct_sketch`` node. ``sketch_need_prompt_generation`` is passed
        as False. The content of the last returned message is printed.
        """
        payload = {
            "messages": [HumanMessage(content="我想设计衬衫,带有猫咪的图案")],
            "call_sketch": True,
            "sketch_need_prompt_generation": False,
        }
        result = await agent.ainvoke(payload)
        print("=== 直接调用 sketch ===")
        final_message = result["messages"][-1]
        print(final_message.content)

    asyncio.run(test_direct())