重构图像生成和搜索工具;更新主代理来处理输入图像
- 更新了“generate_image.py”以接受输入图像以增强图像生成。 - 修改了`pexels_search.py`和`unsplash_search.py`以将日志记录和上传路径从“explorer”更改为“explore”。 - 调整了“print_graph”和“sketch_graph”以提取最新的用户输入并处理输入图像以生成打印和草图图像。 - 重构“generate_print_tool”和“generate_sketch_tool”以接受输入图像。 - 更新了“main_agent.py”以包含状态中的输入图像并调整了图形构建过程。 - 增强了“service.py”来管理输入图像并改进了流媒体期间的事件处理。 - 更新了新软件包和版本的“pyproject.toml”和“uv.lock”中的依赖项。
This commit is contained in:
@@ -1,16 +1,11 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Required, TypedDict
|
||||
from deepagents import CompiledSubAgent, create_deep_agent
|
||||
from langchain.agents import create_agent
|
||||
from langchain.tools import tool
|
||||
from langchain_core.messages import AnyMessage, HumanMessage
|
||||
from langchain_qwq import ChatQwen
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
from app.service.fashion_agent.graph_node.design_graph.graph import build_design_graph
|
||||
from app.service.fashion_agent.graph_node.design_graph.tools import design_tool
|
||||
from app.service.fashion_agent.graph_node.explorer_graph.tools import explor_tool
|
||||
from app.service.fashion_agent.graph_node.explore_graph.tools import explore_tool
|
||||
from app.service.fashion_agent.graph_node.logo_graph.graph import build_logo_graph
|
||||
from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool
|
||||
from app.service.fashion_agent.graph_node.print_graph.graph import build_print_graph
|
||||
@@ -18,7 +13,7 @@ from app.service.fashion_agent.graph_node.print_graph.tools import generate_prin
|
||||
from app.service.fashion_agent.graph_node.sketch_graph.graph import build_sketch_graph
|
||||
from app.service.fashion_agent.graph_node.sketch_graph.tools import generate_sketch_tool
|
||||
from app.service.fashion_agent.graph_node.trending_graph.trending_graph import build_trending_graph
|
||||
from app.service.fashion_agent.graph_node.explorer_graph.graph import build_explorer_graph
|
||||
from app.service.fashion_agent.graph_node.explore_graph.graph import build_explore_graph
|
||||
from app.service.fashion_agent.init_llm import build_llm
|
||||
|
||||
print_graph = build_print_graph()
|
||||
@@ -26,13 +21,16 @@ logo_graph = build_logo_graph()
|
||||
sketch_graph = build_sketch_graph()
|
||||
design_graph = build_design_graph()
|
||||
trending_graph = build_trending_graph()
|
||||
explorer_graph = build_explorer_graph()
|
||||
explore_graph = build_explore_graph()
|
||||
|
||||
|
||||
class MainState(TypedDict):
|
||||
# 消息
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
|
||||
# 上传图片
|
||||
input_images: list[str] = []
|
||||
|
||||
# 模块控制
|
||||
call_design: bool = False
|
||||
call_print: bool = False
|
||||
@@ -40,7 +38,7 @@ class MainState(TypedDict):
|
||||
call_sketch: bool = False
|
||||
call_design: bool = False
|
||||
call_trending: bool = False
|
||||
call_explor: bool = False
|
||||
call_explore: bool = False
|
||||
|
||||
# design参数
|
||||
design_request_data: dict = {}
|
||||
@@ -58,7 +56,7 @@ class MainState(TypedDict):
|
||||
print_img_urls: list[str] = []
|
||||
|
||||
|
||||
tools = [explor_tool, generate_logo_tool, generate_print_tool, generate_sketch_tool]
|
||||
tools = [explore_tool, generate_logo_tool, generate_print_tool, generate_sketch_tool]
|
||||
|
||||
|
||||
def route_node(state: MainState) -> str:
|
||||
@@ -73,12 +71,12 @@ def route_node(state: MainState) -> str:
|
||||
return "direct_design"
|
||||
if state.get("call_trending"):
|
||||
return "direct_trending"
|
||||
if state.get("call_explor"):
|
||||
return "direct_explor"
|
||||
if state.get("call_explore"):
|
||||
return "direct_explore"
|
||||
return "llm_agent"
|
||||
|
||||
|
||||
def build_main_graph(enable_thinking: bool = False):
|
||||
async def build_main_graph(enable_thinking: bool = False, checkpointer=None):
|
||||
llm = build_llm(enable_thinking=enable_thinking)
|
||||
chat_agent = create_agent(
|
||||
model=llm, tools=tools, state_schema=MainState, system_prompt="你是一个专业的服装设计助手。根据用户需求,调用合适的工具完成任务."
|
||||
@@ -94,7 +92,7 @@ def build_main_graph(enable_thinking: bool = False):
|
||||
workflow.add_node("direct_sketch", sketch_graph)
|
||||
workflow.add_node("direct_design", design_graph)
|
||||
workflow.add_node("direct_trending", trending_graph)
|
||||
workflow.add_node("direct_explor", explorer_graph)
|
||||
workflow.add_node("direct_explore", explore_graph)
|
||||
|
||||
# 条件分支
|
||||
workflow.add_conditional_edges(
|
||||
@@ -107,7 +105,7 @@ def build_main_graph(enable_thinking: bool = False):
|
||||
"direct_sketch": "direct_sketch",
|
||||
"direct_design": "direct_design",
|
||||
"direct_trending": "direct_trending",
|
||||
"direct_explor": "direct_explor",
|
||||
"direct_explore": "direct_explore",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -118,27 +116,7 @@ def build_main_graph(enable_thinking: bool = False):
|
||||
workflow.add_edge("direct_sketch", END)
|
||||
workflow.add_edge("direct_design", END)
|
||||
workflow.add_edge("direct_trending", END)
|
||||
workflow.add_edge("direct_explor", END)
|
||||
workflow.add_edge("direct_explore", END)
|
||||
|
||||
return workflow.compile()
|
||||
|
||||
|
||||
agent = build_main_graph()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
async def test_direct():
|
||||
# 直接调用 sketch,跳过 LLM
|
||||
result = await agent.ainvoke(
|
||||
{
|
||||
"messages": [HumanMessage(content="我想设计衬衫,带有猫咪的图案")],
|
||||
"call_sketch": True,
|
||||
"sketch_need_prompt_generation": False,
|
||||
}
|
||||
)
|
||||
print("=== 直接调用 sketch ===")
|
||||
print(result["messages"][-1].content)
|
||||
|
||||
asyncio.run(test_direct())
|
||||
graph = workflow.compile(checkpointer=checkpointer)
|
||||
return graph
|
||||
|
||||
Reference in New Issue
Block a user