Files
AiDA_Python/app/service/fashion_agent/main_agent.py

145 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sys
from pathlib import Path
from typing import Annotated, Required, TypedDict
from deepagents import CompiledSubAgent, create_deep_agent
from langchain.agents import create_agent
from langchain.tools import tool
from langchain_core.messages import AnyMessage, HumanMessage
from langchain_qwq import ChatQwen
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from app.service.fashion_agent.graph_node.design_graph.graph import build_design_graph
from app.service.fashion_agent.graph_node.design_graph.tools import design_tool
from app.service.fashion_agent.graph_node.explorer_graph.tools import explor_tool
from app.service.fashion_agent.graph_node.logo_graph.graph import build_logo_graph
from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool
from app.service.fashion_agent.graph_node.print_graph.graph import build_print_graph
from app.service.fashion_agent.graph_node.print_graph.tools import generate_print_tool
from app.service.fashion_agent.graph_node.sketch_graph.graph import build_sketch_graph
from app.service.fashion_agent.graph_node.sketch_graph.tools import generate_sketch_tool
from app.service.fashion_agent.graph_node.trending_graph.trending_graph import build_trending_graph
from app.service.fashion_agent.graph_node.explorer_graph.graph import build_explorer_graph
from app.service.fashion_agent.init_llm import build_llm
# Compile every specialised sub-graph once, when this module is imported.
# build_main_graph() below registers these same module-level objects as nodes.
# The right-hand tuple is evaluated left to right, so the builders run in the
# same order as the names listed on the left.
(
    print_graph,
    logo_graph,
    sketch_graph,
    design_graph,
    trending_graph,
    explorer_graph,
) = (
    build_print_graph(),
    build_logo_graph(),
    build_sketch_graph(),
    build_design_graph(),
    build_trending_graph(),
    build_explorer_graph(),
)
class MainState(TypedDict):
    """State schema shared by the top-level fashion-agent graph and its agent.

    Only ``messages`` is explicitly marked ``Required``. The remaining keys are
    routing flags and per-module parameters that callers usually omit; for
    example, ``route_node`` reads the ``call_*`` flags with ``state.get(...)``,
    so an absent flag is treated as unset.

    ``TypedDict`` cannot declare per-key default values. A right-hand side such
    as ``= False`` would never be applied to a state dict and is rejected by
    type checkers. Code that reads optional keys must therefore handle their
    absence itself, e.g. with ``.get()``.
    """

    # Conversation history; node updates are combined via the ``add_messages``
    # reducer attached through ``Annotated``.
    messages: Required[Annotated[list[AnyMessage], add_messages]]

    # Module control flags: when one is truthy, ``route_node`` sends the run
    # straight to the matching sub-graph instead of the LLM agent.
    call_design: bool
    call_print: bool
    call_logo: bool
    call_sketch: bool
    call_trending: bool
    call_explor: bool

    # Parameters for the design module.
    design_request_data: dict

    # Module requirement flags. Presumably these tell the print / sketch
    # sub-graphs whether to generate a prompt themselves; confirm against
    # those sub-graphs.
    print_need_prompt_generation: bool
    sketch_need_prompt_generation: bool

    # Parameters shared across modules.
    role: str
    gender: str
    style: str

    # Results of the print module (image URLs).
    print_img_urls: list[str]
# Tools offered to the free-form LLM agent (the "llm_agent" node built in
# build_main_graph).
# NOTE(review): ``design_tool`` is imported by this module but not listed here,
# and no trending tool is imported at all. Confirm this is intentional.
tools = [
    explor_tool,
    generate_logo_tool,
    generate_print_tool,
    generate_sketch_tool,
]
def route_node(state: MainState) -> str:
    """Choose which node should handle the request, based on ``call_*`` flags.

    Flags are checked in a fixed priority order: print, logo, sketch, design,
    trending, explor. The first truthy flag wins. Flags are read with
    ``dict.get``, so a missing key counts as unset. If no flag is set, the
    request goes to the general-purpose ``"llm_agent"`` node.

    Returns:
        The name of one of the nodes registered in ``build_main_graph``.
    """
    # (state key, target node), highest priority first.
    priority = (
        ("call_print", "direct_print"),
        ("call_logo", "direct_logo"),
        ("call_sketch", "direct_sketch"),
        ("call_design", "direct_design"),
        ("call_trending", "direct_trending"),
        ("call_explor", "direct_explor"),
    )
    for flag_key, target_node in priority:
        if state.get(flag_key):
            return target_node
    return "llm_agent"
def build_main_graph(enable_thinking: bool = False):
    """Build and compile the top-level fashion-agent graph.

    Routing happens once, at ``START``, through ``route_node``. If one of the
    ``call_*`` flags is set, the run goes straight to the matching pre-built
    sub-graph. Otherwise a tool-calling LLM agent handles it. Every node then
    goes directly to ``END``.

    Args:
        enable_thinking: Forwarded to ``build_llm`` when building the model used
            by the LLM agent node.

    Returns:
        The result of ``StateGraph.compile()`` for the assembled workflow.
    """
    llm = build_llm(enable_thinking=enable_thinking)
    chat_agent = create_agent(
        model=llm,
        tools=tools,
        state_schema=MainState,
        system_prompt="你是一个专业的服装设计助手。根据用户需求,调用合适的工具完成任务.",
    )

    # Single source of truth for the node set: node name -> runnable.
    # Insertion order fixes the order in which nodes and edges are added. The
    # names must match the strings that route_node can return.
    nodes = {
        "llm_agent": chat_agent,
        "direct_print": print_graph,
        "direct_logo": logo_graph,
        "direct_sketch": sketch_graph,
        "direct_design": design_graph,
        "direct_trending": trending_graph,
        "direct_explor": explorer_graph,
    }

    workflow = StateGraph(MainState)
    for node_name, runnable in nodes.items():
        workflow.add_node(node_name, runnable)

    # Conditional entry point. route_node's return value is looked up in this
    # identity path map, so every name it returns must be a registered node.
    workflow.add_conditional_edges(
        START,
        route_node,
        {node_name: node_name for node_name in nodes},
    )

    # Single-shot workflow: every path ends the run (no loop back to routing).
    for node_name in nodes:
        workflow.add_edge(node_name, END)

    return workflow.compile()
# Module-level compiled graph, built once at import time with the default
# arguments of build_main_graph (enable_thinking=False).
agent = build_main_graph()

if __name__ == "__main__":
    import asyncio

    async def _smoke_test_direct_sketch():
        """Call the sketch sub-graph directly, skipping the LLM agent."""
        payload = {
            "messages": [HumanMessage(content="我想设计衬衫,带有猫咪的图案")],
            "call_sketch": True,
            "sketch_need_prompt_generation": False,
        }
        result = await agent.ainvoke(payload)
        print("=== 直接调用 sketch ===")
        last_message = result["messages"][-1]
        print(last_message.content)

    asyncio.run(_smoke_test_direct_sketch())