From 35e791b4e22ec3af9e7fdc4fa8551e22b618392b Mon Sep 17 00:00:00 2001 From: zcr Date: Mon, 15 Jun 2026 17:10:04 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E5=9B=BE=E5=BD=A2=E7=94=9F?= =?UTF-8?q?=E6=88=90=E5=B7=A5=E5=85=B7=EF=BC=8C=E4=BC=98=E5=8C=96=E8=BF=94?= =?UTF-8?q?=E5=9B=9E=E6=A0=BC=E5=BC=8F=E5=B9=B6=E6=B7=BB=E5=8A=A0=E6=96=B0?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../graph_node/explorer_graph/tools.py | 8 ++++--- .../graph_node/logo_graph/graph.py | 2 +- .../generate_logo.py => logo_graph/tools.py} | 2 +- .../graph_node/print_graph/tools.py | 2 +- .../graph_node/sketch_graph/graph.py | 22 +++++++++---------- .../graph_node/sketch_graph/tools.py | 2 +- app/service/fashion_agent/init_llm.py | 2 ++ app/service/fashion_agent/main_agent.py | 2 +- app/service/fashion_agent/service.py | 4 +++- pyproject.toml | 2 ++ uv.lock | 4 ++++ 11 files changed, 31 insertions(+), 21 deletions(-) rename app/service/fashion_agent/graph_node/{node_tools/generate_logo.py => logo_graph/tools.py} (99%) diff --git a/app/service/fashion_agent/graph_node/explorer_graph/tools.py b/app/service/fashion_agent/graph_node/explorer_graph/tools.py index 54afc49..5e5a3ce 100644 --- a/app/service/fashion_agent/graph_node/explorer_graph/tools.py +++ b/app/service/fashion_agent/graph_node/explorer_graph/tools.py @@ -23,12 +23,14 @@ async def explor_tool( # 方式 1:从 configurable 获取 user_id = config.get("configurable", {}).get("user_id", "agent") + results = [] if method == "unsplash": - return await get_random_photos(query, count=per_page, user_id=user_id) + results = await get_random_photos(query, count=per_page, user_id=user_id) elif method == "pexels": - return await search_photos(query, per_page=per_page, user_id=user_id) + results = await search_photos(query, per_page=per_page, user_id=user_id) else: - pass + results = [] + return results if __name__ == "__main__": diff --git a/app/service/fashion_agent/graph_node/logo_graph/graph.py b/app/service/fashion_agent/graph_node/logo_graph/graph.py index 6aa9c22..080849c 100644 --- a/app/service/fashion_agent/graph_node/logo_graph/graph.py +++ b/app/service/fashion_agent/graph_node/logo_graph/graph.py @@ -6,7 +6,7 @@ from langchain_core.messages import HumanMessage, SystemMessage from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages from pydantic import BaseModel, Field -from app.service.fashion_agent.graph_node.node_tools.generate_logo import generate_logo_tool +from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool from app.service.fashion_agent.init_llm import qwen_plus_llm """初始化 LLM TODO 将 API Key 替换为环境变量或者配置文件中的值,避免在代码中硬编码敏感信息""" diff --git a/app/service/fashion_agent/graph_node/node_tools/generate_logo.py b/app/service/fashion_agent/graph_node/logo_graph/tools.py similarity index 99% rename from app/service/fashion_agent/graph_node/node_tools/generate_logo.py rename to app/service/fashion_agent/graph_node/logo_graph/tools.py index 3b704ab..81162bb 100644 --- a/app/service/fashion_agent/graph_node/node_tools/generate_logo.py +++ b/app/service/fashion_agent/graph_node/logo_graph/tools.py @@ -71,7 +71,7 @@ async def generate_logo_tool(prompt: str, user_id: str = "agent") -> str: file_name = f"{uuid7()}.png" loop = asyncio.get_event_loop() image_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "logo", file_name) - return image_url + return [image_url] if __name__ == "__main__": diff --git a/app/service/fashion_agent/graph_node/print_graph/tools.py b/app/service/fashion_agent/graph_node/print_graph/tools.py index bf91e37..86be8f8 100644 --- a/app/service/fashion_agent/graph_node/print_graph/tools.py +++ b/app/service/fashion_agent/graph_node/print_graph/tools.py @@ -20,7 +20,7 @@ async def generate_print_tool(prompt: str) -> str: bucket_name = "aida-users" object_name = f"agent_generate_print/{uuid7()}.png" image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name) - return image_url + return [image_url] @tool diff --git a/app/service/fashion_agent/graph_node/sketch_graph/graph.py b/app/service/fashion_agent/graph_node/sketch_graph/graph.py index 4d44395..5f251fb 100644 --- a/app/service/fashion_agent/graph_node/sketch_graph/graph.py +++ b/app/service/fashion_agent/graph_node/sketch_graph/graph.py @@ -79,21 +79,19 @@ def generate_sketch_prompt_node(state: SketchState) -> dict: async def generate_sketch_img_node(state: SketchState) -> dict: """根据生成的提示词,生成服装草图""" # 如果 sketch_need_prompt_generation=False 且 sketch_prompts 为空,使用模板生成 prompt - # if not state.get("sketch_need_prompt_generation", False) and not state.get("sketch_prompts"): + if not state.get("sketch_need_prompt_generation", False) and not state.get("sketch_prompts"): - # input_text = state.get("input_text", "") - # prompts = [build_sketch_template_prompt(input_text)] - # else: - # prompts = state["sketch_prompts"] if state["sketch_prompts"] else [state["input_text"]] + input_text = state.get("input_text", "") + prompts = [build_sketch_template_prompt(input_text)] + else: + prompts = state["sketch_prompts"] if state["sketch_prompts"] else [state["input_text"]] - # sketch_img_urls = [] - # for prompt in prompts: - # image_url = await generate_sketch_tool.ainvoke({"prompt": prompt}) - # sketch_img_urls.append(image_url) + sketch_img_urls = [] + for prompt in prompts: + image_url = await generate_sketch_tool.ainvoke({"prompt": prompt}) + sketch_img_urls.append(image_url) - # result_text = f"服装草图生成完成,共生成 {len(sketch_img_urls)} 张图片:\n" + "\n".join(sketch_img_urls) - # return {"sketch_img_urls": sketch_img_urls, "messages": [AIMessage(content=result_text)]} - return {"messages": [AIMessage(content="hello")]} + return {"sketch_img_urls": sketch_img_urls} """条件分支 判断是否需要生成 prompt""" diff --git a/app/service/fashion_agent/graph_node/sketch_graph/tools.py b/app/service/fashion_agent/graph_node/sketch_graph/tools.py index 7393583..7142b34 100644 --- a/app/service/fashion_agent/graph_node/sketch_graph/tools.py +++ b/app/service/fashion_agent/graph_node/sketch_graph/tools.py @@ -20,7 +20,7 @@ async def generate_sketch_tool(prompt: str) -> str: bucket_name = "fida-public-bucket" object_name = f"test/{uuid7()}.png" image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name) - return image_url + return [image_url] async def run_test(): diff --git a/app/service/fashion_agent/init_llm.py b/app/service/fashion_agent/init_llm.py index ca053c2..01c2a9e 100644 --- a/app/service/fashion_agent/init_llm.py +++ b/app/service/fashion_agent/init_llm.py @@ -29,3 +29,5 @@ qwen_plus_llm = ChatQwen( top_p=0.8, api_key=QWEN_API_KEY_INTL, ) +# response = qwen_plus_llm.invoke("你好") +# print(response) diff --git a/app/service/fashion_agent/main_agent.py b/app/service/fashion_agent/main_agent.py index 36a2383..e112733 100644 --- a/app/service/fashion_agent/main_agent.py +++ b/app/service/fashion_agent/main_agent.py @@ -12,7 +12,7 @@ from app.service.fashion_agent.graph_node.design_graph.graph import build_design from app.service.fashion_agent.graph_node.design_graph.tools import design_tool from app.service.fashion_agent.graph_node.explorer_graph.tools import explor_tool from app.service.fashion_agent.graph_node.logo_graph.graph import build_logo_graph -from app.service.fashion_agent.graph_node.node_tools.generate_logo import generate_logo_tool +from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool from app.service.fashion_agent.graph_node.print_graph.graph import build_print_graph from app.service.fashion_agent.graph_node.print_graph.tools import generate_print_tool from app.service.fashion_agent.graph_node.sketch_graph.graph import build_sketch_graph diff --git a/app/service/fashion_agent/service.py b/app/service/fashion_agent/service.py index bf14f4d..7df6e70 100644 --- a/app/service/fashion_agent/service.py +++ b/app/service/fashion_agent/service.py @@ -6,7 +6,7 @@ from langgraph.stream import ProtocolEvent, StreamChannel, StreamTransformer from app.service.fashion_agent.main_agent import build_main_graph from langgraph.prebuilt import ToolCallTransformer from typing import AsyncGenerator, TypedDict -from langchain_core.messages import HumanMessage +from langchain_core.messages import HumanMessage, ToolMessage from app.schemas.fashion_agent import FashionAgentRequest logger = logging.getLogger() @@ -71,6 +71,8 @@ class FashionAgentService: data["tool_name"] = tool_name + if isinstance(data["output"], ToolMessage): + data["output"] = json.loads(data["output"].content) response_event = {"event_type": "tool", "data": data} yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n" diff --git a/pyproject.toml b/pyproject.toml index 89c6d08..3748643 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,10 +20,12 @@ dependencies = [ "image>=1.5.33", "langchain>=1.2.0", "langchain-community>=0.4.1", + "langchain-openai>=1.2.2", "langchain-qwq>=0.3.5", "langgraph>=1.0.5", "langgraph-api>=0.4.28", "langgraph-cli[inmem,redis]<=0.4.26", + "langsmith>=0.8.11", "load>=1.0.14", "load-dotenv>=0.1.0", "loguru>=0.7.3", diff --git a/uv.lock b/uv.lock index 44e39c1..87e4493 100755 --- a/uv.lock +++ b/uv.lock @@ -3739,10 +3739,12 @@ dependencies = [ { name = "image" }, { name = "langchain" }, { name = "langchain-community" }, + { name = "langchain-openai" }, { name = "langchain-qwq" }, { name = "langgraph" }, { name = "langgraph-api" }, { name = "langgraph-cli", extra = ["inmem"] }, + { name = "langsmith" }, { name = "load" }, { name = "load-dotenv" }, { name = "loguru" }, @@ -3796,10 +3798,12 @@ requires-dist = [ { name = "image", specifier = ">=1.5.33" }, { name = "langchain", specifier = ">=1.2.0" }, { name = "langchain-community", specifier = ">=0.4.1" }, + { name = "langchain-openai", specifier = ">=1.2.2" }, { name = "langchain-qwq", specifier = ">=0.3.5" }, { name = "langgraph", specifier = ">=1.0.5" }, { name = "langgraph-api", specifier = ">=0.4.28" }, { name = "langgraph-cli", extras = ["inmem", "redis"], specifier = "<=0.4.26" }, + { name = "langsmith", specifier = ">=0.8.11" }, { name = "load", specifier = ">=1.0.14" }, { name = "load-dotenv", specifier = ">=0.1.0" }, { name = "loguru", specifier = ">=0.7.3" },