重构图像生成和搜索工具;更新主代理来处理输入图像

- 更新了“generate_image.py”以接受输入图像以增强图像生成。
- 修改了 `pexels_search.py` 和 `unsplash_search.py`,将日志前缀和上传路径从“explorer”更改为“explore”。
- 调整了“print_graph”和“sketch_graph”,使其提取最新一条用户输入,并在生成印花和草图图像时处理输入图像。
- 重构“generate_print_tool”和“generate_sketch_tool”以接受输入图像。
- 更新了“main_agent.py”,在状态中加入输入图像,并调整了图(graph)的构建过程。
- 增强了“service.py”,使其管理输入图像,并改进了流式输出期间的事件处理。
- 更新了“pyproject.toml”和“uv.lock”中的依赖项,加入了新软件包并升级了版本。
This commit is contained in:
zcr
2026-06-17 11:56:53 +08:00
parent 35e791b4e2
commit b9163f0b46
16 changed files with 296 additions and 212 deletions

View File

@@ -9,8 +9,8 @@ from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from langchain_core.runnables import RunnableConfig
from app.service.fashion_agent.graph_node.explorer_graph.tools import explor_tool
from app.service.fashion_agent.init_llm import build_llm
from app.service.fashion_agent.graph_node.explore_graph.tools import explore_tool
from app.service.fashion_agent.init_llm import qwen_plus_llm
logger = logging.getLogger()
@@ -18,10 +18,10 @@ logger = logging.getLogger()
"""定义状态"""
class ExplorerState(TypedDict):
class exploreState(TypedDict):
messages: Required[Annotated[list[AnyMessage], add_messages]]
input_text: str
search_query: str
input_text: str = ""
search_query: str = ""
image_results: list[dict] # 每项包含 image_url 和 minio_path
provider: str = "unsplash" # 图片源: "pexels" 或 "unsplash"
@@ -29,71 +29,70 @@ class ExplorerState(TypedDict):
"""节点"""
def extract_input_node(state: ExplorerState) -> dict:
def extract_input_node(state: exploreState) -> dict:
"""从 messages 中提取用户输入"""
input_text = state["messages"][0].content if state.get("messages") else ""
input_text = state["messages"][-1].content if state.get("messages") else input_text
return {"input_text": input_text}
class SearchQuery(BaseModel):
"""搜索关键词"""
query: str = Field(description="用于搜索灵感图片的英文关键词简洁有力")
query: str = Field(description="用于搜索灵感图片的英文关键词,简洁有力")
# TODO 要考虑搜索图片失败或者图片不存在的情况 搜索不到 需要调整搜索词或者拆分搜索词最终失败的话调用mood board生成工具生成 保证绝对有图片
async def generate_query_node(state: ExplorerState) -> dict:
"""使用 LLM 分析用户输入生成搜索关键词"""
# TODO 要考虑搜索图片失败或者图片不存在的情况, 搜索不到 需要调整搜索词或者拆分搜索词,最终失败的话调用mood board生成工具生成, 保证绝对有图片
async def generate_query_node(state: exploreState) -> dict:
"""使用 LLM 分析用户输入,生成搜索关键词"""
input_text = state["input_text"]
logger.info(f"[Explorer] 用户输入: {input_text}")
llm = build_llm()
logger.info(f"[explore] 用户输入: {input_text}")
structured_llm = llm.with_structured_output(SearchQuery)
structured_llm = qwen_plus_llm.with_structured_output(SearchQuery)
messages = [
SystemMessage(content="""你是一个专业服装设计师助手
根据用户输入生成一个英文搜索关键词用于在图片库中搜索服装设计灵感图片moodboard
SystemMessage(content="""你是专业服装设计师助手.
根据用户中文需求,生成适合时尚灵感图(moodboard)图库搜索的英文关键词短句.
要求
1. 使用英文简洁有力
2. 适合搜索高质量的设计灵感图片
严格输出规则:
1. 必须返回标准JSON对象,**禁止输出任何额外文字,解释,思考,前言后语**;
2. JSON 只包含一个字段 "query",值为简洁英文搜索词;
3. 关键词简洁,适配高清时尚素材搜索;
例如
用户输入"夏季连衣裙,清新风格"
输出summer dress fresh style"""),
输出格式示例(仅允许输出如下JSON,不要加别的内容):
{"query": "summer dress fresh style"}"""),
HumanMessage(content=input_text),
]
result = structured_llm.invoke(messages)
logger.info(f"[Explorer] LLM 生成的搜索关键词: {result.query}")
return {"search_query": result.query}
search_query = result.query
logger.info(f"[explore] LLM 生成的搜索关键词: {search_query}")
return {"search_query": search_query}
async def search_and_upload_node(state: ExplorerState, config: RunnableConfig) -> dict:
async def search_and_upload_node(state: exploreState, config: RunnableConfig) -> dict:
"""使用搜索关键词获取图片并上传到 minio"""
query = state.get("search_query", "")
user_id = state.get("user_id", "agent")
provider = state.get("provider", "unsplash")
try:
results = await explor_tool.ainvoke({"query": query, "per_page": 4, "user_id": user_id, "method": provider}, config=config)
results = await explore_tool.ainvoke({"query": query, "per_page": 4, "user_id": user_id, "method": provider}, config=config)
except Exception as e:
logger.error(f"[Explorer] 搜索失败 '{query}': {e}")
logger.error(f"[explore] 搜索失败 '{query}': {e}")
results = []
return {"image_results": results}
def summarize_node(state: ExplorerState) -> dict:
def summarize_node(state: exploreState) -> dict:
"""汇总结果"""
input_text = state.get("input_text", "")
query = state.get("search_query", "")
results = state.get("image_results", [])
result_text = f"灵感探索 Moodboard\n\n"
result_text += f"基于您的需求:「{input_text}\n"
result_text += f"搜索关键词{query}\n\n"
result_text += f"已为您找到 {len(results)} 张灵感图片\n"
result_text = f"[灵感探索 Moodboard]\n\n"
result_text += f"基于您的需求: {input_text}\n\n"
result_text += f"搜索关键词:{query}\n\n"
result_text += f"已为您找到 {len(results)} 张灵感图片:\n"
for i, item in enumerate(results, 1):
result_text += f" {i}. 原图: {item.get('image_url', '')}\n"
@@ -105,9 +104,9 @@ def summarize_node(state: ExplorerState) -> dict:
"""构建图"""
def build_explorer_graph():
def build_explore_graph():
"""构建灵感探索图"""
workflow = StateGraph(ExplorerState)
workflow = StateGraph(exploreState)
workflow.add_node("extract_input", extract_input_node)
workflow.add_node("generate_query", generate_query_node)
@@ -126,10 +125,10 @@ def build_explorer_graph():
if __name__ == "__main__":
async def test():
graph = build_explorer_graph()
graph = build_explore_graph()
result = await graph.ainvoke(
{
"messages": [HumanMessage(content="夏季连衣裙清新自然风格")],
"messages": [HumanMessage(content="夏季连衣裙,清新自然风格")],
"provider": "unsplash",
}
)

View File

@@ -15,7 +15,7 @@ class SearchInput(BaseModel):
@tool(args_schema=SearchInput)
async def explor_tool(
async def explore_tool(
query: str, per_page: int = 4, user_id: str = "agent", method: str = "unsplash", config: RunnableConfig = None
) -> list[dict]:
"""Search for fashion inspiration images on Unsplash and upload to minio. Returns a list of dicts with image_url and minio_path."""

View File

@@ -7,7 +7,7 @@ from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool
from app.service.fashion_agent.init_llm import qwen_plus_llm
from app.service.fashion_agent.init_llm import build_llm, qwen_plus_llm
"""初始化 LLM TODO 将 API Key 替换为环境变量或者配置文件中的值,避免在代码中硬编码敏感信息"""
@@ -17,7 +17,7 @@ from app.service.fashion_agent.init_llm import qwen_plus_llm
class LogoState(TypedDict):
messages: Required[Annotated[list[AnyMessage], add_messages]]
input_text: str
input_text: str = ""
user_id: str = "agent"
role: str = ""
gender: str = ""
@@ -35,14 +35,14 @@ class LogoState(TypedDict):
# 定义输出结构
class LogoPrompt(BaseModel):
"""生成的 Logo 图像提示词"""
"""logo image generation diagram prompt words"""
prompts: list[str] = Field(description="用于生成 Logo 的详细提示词")
prompts: list[str] = Field(description="Array of prompt words, simple English words")
def extract_input_node(state: LogoState) -> dict:
"""从 messages 中提取用户输入"""
input_text = state["messages"][0].content if state.get("messages") else ""
input_text = state["messages"][-1].content if state.get("messages") else state["input_text"]
return {"input_text": input_text}
@@ -51,18 +51,19 @@ def generate_logo_prompt_node(state: LogoState) -> dict:
structured_llm = qwen_plus_llm.with_structured_output(LogoPrompt)
messages = [
SystemMessage(content="""从用户输入中提取核心主题词,只输出一个简单的英文单词。
SystemMessage(content="""从用户输入中提取核心主题词,一个简单的英文单词作为 prompts 字段的值
例如:
- "我想要一个猫咪图案" -> "cat"
- "设计一个花朵" -> "flower"
- "可爱的狗" -> "dog"
只输出单词,不要其他内容。"""),
- "我想要一个猫咪图案" -> {"prompts": ["cat"]}
- "设计一个花朵" -> {"prompts": ["flower"]}
- "可爱的狗" -> {"prompts": ["dog"]}
请严格按照 JSON 格式输出,包含 prompts 字段。
"""),
HumanMessage(content=state["input_text"]),
]
result = structured_llm.invoke(messages)
prompts = result.prompts
print(result)
return {
"logo_prompts": prompts,
}
@@ -139,14 +140,14 @@ async def main(test_input, user_id="agent", need_prompt_generation=True):
if __name__ == "__main__":
# 测试示例 1: 需要 prompt 生成(默认)- 简单关键词输入
test_input = "我想要一个金毛图案"
result = asyncio.run(main(test_input, need_prompt_generation=True))
print("=== 需要 prompt 生成 ===")
print(f"Result: {result}")
# 测试示例 2: 直接使用用户提供的 prompt
user_prompt = "golden retriever"
result = asyncio.run(main(user_prompt, need_prompt_generation=False))
print("\n=== 直接使用 prompt ===")
print(f"Result: {result}")
async def test():
graph = build_logo_graph()
result = await graph.ainvoke(
{
"messages": [HumanMessage(content="我想要一个金毛图案")],
}
)
print(result["messages"][-1].content)
asyncio.run(test())

View File

@@ -4,13 +4,14 @@ import httpx
async def generate_image(
bucket_name="fida-public-bucket",
object_name=f"furniture/sketches/123456.png",
input_images=[],
prompt="Generate a modern minimalist dining chair made of light "
"oak wood and white leather, with slim metal legs, photographed "
"in a bright Scandinavian living room with natural sunlight, high detail, "
"8k resolution.",
):
request_data = {
"input_image_paths": [],
"input_image_paths": input_images,
"prompt": prompt,
"bucket_name": bucket_name,
"object_name": object_name,

View File

@@ -63,10 +63,10 @@ async def search_photos(query: str, per_page: int = 4, user_id: str = "agent") -
# 上传到 minio使用线程池避免阻塞事件循环
file_name = f"{uuid7()}.jpg"
loop = asyncio.get_event_loop()
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explorer", file_name)
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explore", file_name)
results.append({"image_url": image_url, "minio_path": minio_url})
logger.info(f"[Explorer] 上传成功: {minio_url}")
logger.info(f"[explore] 上传成功: {minio_url}")
except Exception as e:
logger.error(f"[Explorer] 上传失败: {e}")
logger.error(f"[explore] 上传失败: {e}")
return results

View File

@@ -63,11 +63,11 @@ async def get_random_photos(query: str, count: int = 4, user_id: str = "agent")
# 上传到 minio使用线程池避免阻塞事件循环
file_name = f"{uuid7()}.jpg"
loop = asyncio.get_event_loop()
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explorer", file_name)
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explore", file_name)
results.append({"image_url": image_url, "minio_path": minio_url})
logger.info(f"[Explorer] 上传成功: {minio_url}")
logger.info(f"[explore] 上传成功: {minio_url}")
except Exception as e:
logger.error(f"[Explorer] 上传失败: {e}")
logger.error(f"[explore] 上传失败: {e}")
return results

View File

@@ -40,7 +40,7 @@ class PrintPrompt(BaseModel):
def extract_input_node(state: PrintState) -> dict:
"""从 messages 中提取用户输入"""
input_text = state["messages"][0].content if state.get("messages") else ""
input_text = state["messages"][-1].content if state.get("messages") else ""
return {"input_text": input_text}
@@ -49,13 +49,13 @@ def generate_print_prompt_node(state: PrintState) -> dict:
structured_llm = qwen_plus_llm.with_structured_output(PrintPrompt)
messages = [
SystemMessage(content=f"""你是一个专业的印花图案设计师
请根据用户输入生成用于AI图像生成的印花图案提示词
SystemMessage(content=f"""你是一个专业的印花图案设计师.
请根据用户输入,生成用于AI图像生成的印花图案提示词.
要求
1. 提示词应该详细描述印花图案的样式元素颜色布局
要求:
1. 提示词应该详细描述印花图案的样式,元素,颜色,布局
2. 提示词应该适合用于 Stable Diffusion 图像生成模型
3. 提示词应该使用英文因为图像生成模型对英文理解更好
3. 提示词应该使用英文,因为图像生成模型对英文理解更好
4. 提示词数量为 {state.get("print_num", 1)}
"""),
HumanMessage(content=state["input_text"]),
@@ -73,17 +73,19 @@ def generate_print_prompt_node(state: PrintState) -> dict:
async def generate_print_img_node(state: PrintState) -> dict:
"""根据生成的提示词生成印花图案"""
# 如果 print_prompts 为空使用 input_text 作为 prompt
"""根据生成的提示词,生成印花图案"""
# 如果 print_prompts 为空,使用 input_text 作为 prompt
if state.get("print_need_prompt_generation", False):
prompts = state["print_prompts"] if state["print_prompts"] else [state["input_text"]]
else:
input_text = state.get("input_text", "")
prompts = [input_text]
input_images = state.get("input_images", [])
print_img_urls = []
for prompt in prompts:
image_url = await generate_print_tool.ainvoke({"prompt": prompt})
image_url = await generate_print_tool.ainvoke({"prompt": prompt, "input_images": input_images})
print_img_urls.append(image_url)
logger.info(f"[Print Graph] Generated print image URL: {image_url}")
@@ -94,7 +96,7 @@ async def generate_print_img_node(state: PrintState) -> dict:
def should_generate_prompt(state: PrintState) -> str:
"""条件分支判断是否需要生成 prompt"""
"""条件分支:判断是否需要生成 prompt"""
logger.info(
f"[Print Graph] should_generate_prompt: print_need_prompt_generation={state.get('print_need_prompt_generation')}, print_prompts={state.get('print_prompts')}"
@@ -106,7 +108,6 @@ def should_generate_prompt(state: PrintState) -> str:
def build_print_graph():
workflow = StateGraph(PrintState)
workflow.add_node("extract_input", extract_input_node)
workflow.add_node("gen_prompt", generate_print_prompt_node)
@@ -145,8 +146,8 @@ async def main(test_input, print_need_prompt_generation=True):
if __name__ == "__main__":
# 测试示例 1: 需要 prompt 生成默认
test_input = "我想要一个优雅的花卉印花适合用于连衣裙颜色以粉色和白色为主"
# 测试示例 1: 需要 prompt 生成(默认)
test_input = "我想要一个优雅的花卉印花,适合用于连衣裙,颜色以粉色和白色为主"
result = asyncio.run(main(test_input, print_need_prompt_generation=True))
print("=== 需要 prompt 生成 ===")
print(f"Result: {result}")

View File

@@ -11,15 +11,16 @@ class GenerateImageToolInput(BaseModel):
"""Input schema for the Generate Image Tool."""
prompt: str = Field(description="Description of the desired image, e.g., 'A cozy living room with warm lighting and natural textures.'")
input_images: list[str] = Field(default=[], description="Input images for the generation.")
@tool(args_schema=GenerateImageToolInput)
async def generate_print_tool(prompt: str) -> str:
async def generate_print_tool(prompt: str, input_images: list[str]) -> str:
"""Generate an image based on the provided prompt."""
bucket_name = "aida-users"
object_name = f"agent_generate_print/{uuid7()}.png"
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name, input_images=input_images)
return [image_url]

View File

@@ -42,7 +42,7 @@ class SketchPrompt(BaseModel):
def extract_input_node(state: SketchState) -> dict:
"""从 messages 中提取用户输入"""
input_text = state["messages"][0].content if state.get("messages") else ""
input_text = state["messages"][-1].content if state.get("messages") else ""
return {"input_text": input_text}
@@ -80,15 +80,16 @@ async def generate_sketch_img_node(state: SketchState) -> dict:
"""根据生成的提示词,生成服装草图"""
# 如果 sketch_need_prompt_generation=False 且 sketch_prompts 为空,使用模板生成 prompt
if not state.get("sketch_need_prompt_generation", False) and not state.get("sketch_prompts"):
input_text = state.get("input_text", "")
prompts = [build_sketch_template_prompt(input_text)]
else:
prompts = state["sketch_prompts"] if state["sketch_prompts"] else [state["input_text"]]
input_images = state.get("input_images", [])
sketch_img_urls = []
for prompt in prompts:
image_url = await generate_sketch_tool.ainvoke({"prompt": prompt})
image_url = await generate_sketch_tool.ainvoke({"prompt": prompt, "input_images": input_images})
sketch_img_urls.append(image_url)
return {"sketch_img_urls": sketch_img_urls}
@@ -107,33 +108,26 @@ def should_generate_prompt(state: SketchState) -> str:
def build_sketch_graph():
workflow = StateGraph(SketchState)
workflow.add_node("extract_input", extract_input_node)
workflow.add_node("gen_prompt", generate_sketch_prompt_node)
workflow.add_node("gen_sketch", generate_sketch_img_node)
workflow.add_edge(START, "gen_sketch")
# 添加边
workflow.add_edge(START, "extract_input")
workflow.add_conditional_edges(
"extract_input",
should_generate_prompt,
{
"gen_prompt": "gen_prompt",
"gen_sketch": "gen_sketch",
},
)
workflow.add_edge("gen_prompt", "gen_sketch")
workflow.add_edge("gen_sketch", END)
graph = workflow.compile()
return graph
# workflow = StateGraph(SketchState)
# workflow.add_node("extract_input", extract_input_node)
# workflow.add_node("gen_prompt", generate_sketch_prompt_node)
# workflow.add_node("gen_sketch", generate_sketch_img_node)
# # 添加边
# workflow.add_edge(START, "extract_input")
# workflow.add_conditional_edges(
# "extract_input",
# should_generate_prompt,
# {
# "gen_prompt": "gen_prompt",
# "gen_sketch": "gen_sketch",
# },
# )
# workflow.add_edge("gen_prompt", "gen_sketch")
# workflow.add_edge("gen_sketch", END)
# graph = workflow.compile()
# return graph
def build_sketch_template_prompt(input_text: str) -> str:
"""构建 sketch prompt 模板"""

View File

@@ -11,15 +11,16 @@ class GenerateImageToolInput(BaseModel):
"""Input schema for the Generate Image Tool."""
prompt: str = Field(description="Description of the desired image, e.g., 'A cozy living room with warm lighting and natural textures.'")
input_images: list[str] = Field(default=[], description="Input images for the generation.")
@tool(args_schema=GenerateImageToolInput)
async def generate_sketch_tool(prompt: str) -> str:
async def generate_sketch_tool(prompt: str, input_images: list[str]) -> str:
"""Generate an image based on the provided prompt."""
bucket_name = "fida-public-bucket"
object_name = f"test/{uuid7()}.png"
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name, input_images=input_images)
return [image_url]

View File

@@ -1,16 +1,11 @@
import sys
from pathlib import Path
from typing import Annotated, Required, TypedDict
from deepagents import CompiledSubAgent, create_deep_agent
from langchain.agents import create_agent
from langchain.tools import tool
from langchain_core.messages import AnyMessage, HumanMessage
from langchain_qwq import ChatQwen
from langchain_core.messages import AnyMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from app.service.fashion_agent.graph_node.design_graph.graph import build_design_graph
from app.service.fashion_agent.graph_node.design_graph.tools import design_tool
from app.service.fashion_agent.graph_node.explorer_graph.tools import explor_tool
from app.service.fashion_agent.graph_node.explore_graph.tools import explore_tool
from app.service.fashion_agent.graph_node.logo_graph.graph import build_logo_graph
from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool
from app.service.fashion_agent.graph_node.print_graph.graph import build_print_graph
@@ -18,7 +13,7 @@ from app.service.fashion_agent.graph_node.print_graph.tools import generate_prin
from app.service.fashion_agent.graph_node.sketch_graph.graph import build_sketch_graph
from app.service.fashion_agent.graph_node.sketch_graph.tools import generate_sketch_tool
from app.service.fashion_agent.graph_node.trending_graph.trending_graph import build_trending_graph
from app.service.fashion_agent.graph_node.explorer_graph.graph import build_explorer_graph
from app.service.fashion_agent.graph_node.explore_graph.graph import build_explore_graph
from app.service.fashion_agent.init_llm import build_llm
print_graph = build_print_graph()
@@ -26,13 +21,16 @@ logo_graph = build_logo_graph()
sketch_graph = build_sketch_graph()
design_graph = build_design_graph()
trending_graph = build_trending_graph()
explorer_graph = build_explorer_graph()
explore_graph = build_explore_graph()
class MainState(TypedDict):
# 消息
messages: Required[Annotated[list[AnyMessage], add_messages]]
# 上传图片
input_images: list[str] = []
# 模块控制
call_design: bool = False
call_print: bool = False
@@ -40,7 +38,7 @@ class MainState(TypedDict):
call_sketch: bool = False
call_design: bool = False
call_trending: bool = False
call_explor: bool = False
call_explore: bool = False
# design参数
design_request_data: dict = {}
@@ -58,7 +56,7 @@ class MainState(TypedDict):
print_img_urls: list[str] = []
tools = [explor_tool, generate_logo_tool, generate_print_tool, generate_sketch_tool]
tools = [explore_tool, generate_logo_tool, generate_print_tool, generate_sketch_tool]
def route_node(state: MainState) -> str:
@@ -73,12 +71,12 @@ def route_node(state: MainState) -> str:
return "direct_design"
if state.get("call_trending"):
return "direct_trending"
if state.get("call_explor"):
return "direct_explor"
if state.get("call_explore"):
return "direct_explore"
return "llm_agent"
def build_main_graph(enable_thinking: bool = False):
async def build_main_graph(enable_thinking: bool = False, checkpointer=None):
llm = build_llm(enable_thinking=enable_thinking)
chat_agent = create_agent(
model=llm, tools=tools, state_schema=MainState, system_prompt="你是一个专业的服装设计助手。根据用户需求,调用合适的工具完成任务."
@@ -94,7 +92,7 @@ def build_main_graph(enable_thinking: bool = False):
workflow.add_node("direct_sketch", sketch_graph)
workflow.add_node("direct_design", design_graph)
workflow.add_node("direct_trending", trending_graph)
workflow.add_node("direct_explor", explorer_graph)
workflow.add_node("direct_explore", explore_graph)
# 条件分支
workflow.add_conditional_edges(
@@ -107,7 +105,7 @@ def build_main_graph(enable_thinking: bool = False):
"direct_sketch": "direct_sketch",
"direct_design": "direct_design",
"direct_trending": "direct_trending",
"direct_explor": "direct_explor",
"direct_explore": "direct_explore",
},
)
@@ -118,27 +116,7 @@ def build_main_graph(enable_thinking: bool = False):
workflow.add_edge("direct_sketch", END)
workflow.add_edge("direct_design", END)
workflow.add_edge("direct_trending", END)
workflow.add_edge("direct_explor", END)
workflow.add_edge("direct_explore", END)
return workflow.compile()
agent = build_main_graph()
if __name__ == "__main__":
import asyncio
async def test_direct():
# 直接调用 sketch跳过 LLM
result = await agent.ainvoke(
{
"messages": [HumanMessage(content="我想设计衬衫,带有猫咪的图案")],
"call_sketch": True,
"sketch_need_prompt_generation": False,
}
)
print("=== 直接调用 sketch ===")
print(result["messages"][-1].content)
asyncio.run(test_direct())
graph = workflow.compile(checkpointer=checkpointer)
return graph

View File

@@ -1,13 +1,12 @@
import json
import logging
import sys
from pathlib import Path
from langgraph.stream import ProtocolEvent, StreamChannel, StreamTransformer
from app.service.fashion_agent.main_agent import build_main_graph
from langgraph.prebuilt import ToolCallTransformer
from typing import AsyncGenerator, TypedDict
from typing import AsyncGenerator
from langchain_core.messages import HumanMessage, ToolMessage
from app.schemas.fashion_agent import FashionAgentRequest
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
logger = logging.getLogger()
@@ -33,73 +32,83 @@ class FashionAgentService:
async def run_stream(self, request: FashionAgentRequest) -> AsyncGenerator[str, None]:
"""流式运行 agent - 使用 v3 projections"""
config = {"configurable": {"user_id": request.user_id}}
config = {"configurable": {"thread_id": request.thread_id, "user_id": request.user_id}}
agent = build_main_graph(enable_thinking=request.enable_thinking)
state = {
"messages": [HumanMessage(content=request.message)],
"call_print": request.call_print,
"call_logo": request.call_logo,
"call_sketch": request.call_sketch,
"call_design": request.call_design,
"call_trending": request.call_trending,
"call_explor": request.call_explor,
"print_need_prompt_generation": request.print_need_prompt_generation,
"sketch_need_prompt_generation": request.sketch_need_prompt_generation,
"design_request_data": request.design_request_data,
}
async with AsyncPostgresSaver.from_conn_string("postgresql://postgres:Aidlab123123!@20.1.1.43:15432/myapp_prod") as checkpointer:
await checkpointer.setup()
agent = await build_main_graph(enable_thinking=request.enable_thinking, checkpointer=checkpointer)
stream = await agent.astream_events(state, config=config, version="v3", transformers=[ToolCallTransformer, CustomTransformer])
state = {
"messages": [HumanMessage(content=request.message)],
"input_images": request.input_images,
"call_print": request.call_print,
"call_logo": request.call_logo,
"call_sketch": request.call_sketch,
"call_design": request.call_design,
"call_trending": request.call_trending,
"call_explore": request.call_explore,
"print_need_prompt_generation": request.print_need_prompt_generation,
"sketch_need_prompt_generation": request.sketch_need_prompt_generation,
"design_request_data": request.design_request_data,
}
tool_names = {}
filter_tool_name = ["design_tool"]
async for event in stream:
if event["method"] == "tools":
data = event["params"]["data"]
tool_call_id = data.get("tool_call_id")
stream = await agent.astream_events(state, config=config, version="v3", transformers=[ToolCallTransformer, CustomTransformer])
# 记录 tool_name
if data.get("event") == "tool-started":
tool_names[tool_call_id] = data.get("tool_name")
tool_names = {}
filter_tool_name = ["design_tool"]
async for event in stream:
if event["method"] == "tools":
data = event["params"]["data"]
tool_call_id = data.get("tool_call_id")
# 通过 ID 查找 tool_name
elif data.get("event") == "tool-finished":
tool_name = tool_names.get(tool_call_id, "unknown")
# 记录 tool_name
if data.get("event") == "tool-started":
tool_names[tool_call_id] = data.get("tool_name")
if tool_name in filter_tool_name:
# 通过 ID 查找 tool_name
elif data.get("event") == "tool-finished":
tool_name = tool_names.get(tool_call_id, "unknown")
if tool_name in filter_tool_name:
continue
data["tool_name"] = tool_name
if isinstance(data["output"], ToolMessage):
data["output"] = json.loads(data["output"].content)
response_event = {"event_type": "tool", "data": data}
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
elif event["method"] == "custom":
data = event["params"]["data"]
response_event = {"event_type": "tool", "data": data}
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
elif event["method"] == "messages":
event_data = event["params"]["data"]
data = event_data[0] if len(event_data) > 0 else {}
# 提取元数据 (如果有的话)
metadata = event_data[1] if len(event_data) > 1 else {}
if not isinstance(data, dict):
continue
if metadata.get("langgraph_node") in {"gen_prompt", "generate_query"}:
continue
data["tool_name"] = tool_name
ev = data.get("event")
if isinstance(data["output"], ToolMessage):
data["output"] = json.loads(data["output"].content)
response_event = {"event_type": "tool", "data": data}
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
if ev == "content-block-delta":
block = data.get("delta") or {}
if block.get("type") in ("text-delta", "reasoning-delta"):
response_event = {"event_type": "messages", "data": {"event": ev} | block}
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
elif event["method"] == "custom":
data = event["params"]["data"]
response_event = {"event_type": "tool", "data": data}
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
elif ev in ("message-start", "content-block-start", "content-block-finish", "message-finish"):
response_event = {"event_type": "messages", "data": {"event": ev} | data}
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
elif event["method"] == "messages":
data = event["params"]["data"][0]
if not isinstance(data, dict):
continue
if data.get("event") != "content-block-delta":
continue
block = data.get("delta") or {}
if block.get("type") == "text-delta":
response_event = {"event_type": "messages", "data": {"event": data["event"]} | block}
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
elif block.get("type") == "reasoning-delta":
response_event = {"event_type": "messages", "data": {"event": data["event"]} | block}
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
else:
pass
# print(f"----------------{event}")
response_event = {"event_type": "done"}
yield f"data: {response_event}"
response_event = {"event_type": "done"}
yield f"data: {response_event}"
if __name__ == "__main__":
@@ -117,13 +126,15 @@ if __name__ == "__main__":
print("测试流式输出")
print("=" * 50)
request = FashionAgentRequest(
message="生成一张草莓图案",
call_print=True,
thread_id="zhh",
message="落日",
# call_print=True,
# input_images=["test/53d38bd5-f77b-4034-ada2-45f1e2ebe00c.png"],
# print_need_prompt_generation=False,
# call_sketch=True,
# sketch_need_prompt_generation=False,
# call_logo=True,
# call_explor=True,
call_explore=True,
# call_design=True,
# design_request_data=request_data,
)