import json
import logging
import os
from typing import Any, AsyncGenerator

from langchain_core.messages import HumanMessage, ToolMessage
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.prebuilt import ToolCallTransformer
from langgraph.stream import ProtocolEvent, StreamChannel, StreamTransformer

from app.schemas.fashion_agent import FashionAgentRequest
from app.service.fashion_agent.main_agent import build_main_graph

logger = logging.getLogger(__name__)

# Environment variable that overrides the Postgres DSN used by the checkpointer.
DB_URL_ENV_VAR = "FASHION_AGENT_DB_URL"

# NOTE(security): this DSN embeds a plaintext password and a private host.
# It is kept only as a last-resort fallback so existing deployments keep
# working; set FASHION_AGENT_DB_URL (or pass ``db_url``) and delete it.
_DEFAULT_DB_URL = "postgresql://postgres:Aidlab123123!@20.1.1.43:15432/myapp_prod"

# Tools whose "tool-finished" events are not forwarded to the client.
FILTERED_TOOL_NAMES = frozenset({"design_tool"})

# Graph nodes whose streamed LLM messages are internal and are not forwarded.
_HIDDEN_MESSAGE_NODES = frozenset({"gen_prompt", "generate_query"})

# "content-block-delta" payload types that are forwarded to the client.
_FORWARDED_DELTA_TYPES = ("text-delta", "reasoning-delta")

# Message lifecycle events that are forwarded verbatim.
_FORWARDED_MESSAGE_EVENTS = (
    "message-start",
    "content-block-start",
    "content-block-finish",
    "message-finish",
)


def _sse(payload: dict) -> str:
    """Serialize *payload* as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


def _decode_tool_output(message: ToolMessage) -> Any:
    """Return the tool message content parsed as JSON.

    Non-string content (e.g. a list of content blocks) is returned as-is,
    and string content that is not valid JSON is returned as the raw text,
    instead of aborting the whole stream.
    """
    content = message.content
    if not isinstance(content, str):
        return content
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        logger.debug("Tool output is not JSON; forwarding raw text")
        return content


class CustomTransformer(StreamTransformer):
    """Expose payloads of ``custom`` stream events through a ``custom`` channel."""

    required_stream_modes = ("custom",)

    def __init__(self, scope: tuple[str, ...] = ()) -> None:
        super().__init__(scope)
        self.log = StreamChannel()

    def init(self) -> dict:
        """Return the channels this transformer provides, keyed by name."""
        return {"custom": self.log}

    def process(self, event: ProtocolEvent) -> bool:
        """Push the ``data`` of every ``custom`` event into the channel.

        Always returns True (presumably "keep the event in the main stream";
        confirm against the langgraph ``StreamTransformer`` contract).
        """
        if event["method"] == "custom":
            self.log.push(event["params"]["data"])
        return True


class FashionAgentService:
    """Run the fashion agent graph and re-emit its event stream as SSE frames."""

    def __init__(self, db_url: str | None = None) -> None:
        """
        Args:
            db_url: Postgres connection string for the LangGraph checkpointer.
                Falls back to the FASHION_AGENT_DB_URL environment variable,
                then to the legacy built-in DSN.
        """
        self.db_url = db_url or os.environ.get(DB_URL_ENV_VAR) or _DEFAULT_DB_URL

    async def run_stream(self, request: FashionAgentRequest) -> AsyncGenerator[str, None]:
        """Stream one agent run as SSE frames, using the v3 event projections.

        Args:
            request: user message, thread/user IDs and the per-tool switches
                copied into the initial graph state.

        Yields:
            ``data: <json>\\n\\n`` frames whose payload has ``event_type``
            ``"tool"`` (tool events and ``custom`` events), ``"messages"``
            (LLM message events), and finally ``"done"``.
        """
        config = {"configurable": {"thread_id": request.thread_id, "user_id": request.user_id}}
        async with AsyncPostgresSaver.from_conn_string(self.db_url) as checkpointer:
            # Prepares the checkpoint schema; note this runs on every request.
            await checkpointer.setup()
            agent = await build_main_graph(enable_thinking=request.enable_thinking, checkpointer=checkpointer)
            state = {
                "messages": [HumanMessage(content=request.message)],
                "input_images": request.input_images,
                "call_print": request.call_print,
                "call_logo": request.call_logo,
                "call_sketch": request.call_sketch,
                "call_design": request.call_design,
                "call_trending": request.call_trending,
                "call_explore": request.call_explore,
                "print_need_prompt_generation": request.print_need_prompt_generation,
                "sketch_need_prompt_generation": request.sketch_need_prompt_generation,
                "design_request_data": request.design_request_data,
            }
            stream = await agent.astream_events(
                state,
                config=config,
                version="v3",
                transformers=[ToolCallTransformer, CustomTransformer],
            )
            # tool_call_id -> tool_name, learned from "tool-started" events so
            # that "tool-finished" events can be labelled by ID.
            tool_names: dict[Any, Any] = {}
            async for event in stream:
                method = event["method"]
                if method == "tools":
                    frame = self._tool_frame(event["params"]["data"], tool_names)
                elif method == "custom":
                    frame = _sse({"event_type": "tool", "data": event["params"]["data"]})
                elif method == "messages":
                    frame = self._message_frame(event["params"]["data"])
                else:
                    frame = None
                if frame is not None:
                    yield frame
            # Terminal frame: real JSON and a proper SSE terminator, like every other frame.
            yield _sse({"event_type": "done"})

    @staticmethod
    def _tool_frame(data: dict, tool_names: dict) -> str | None:
        """Build the SSE frame for a ``tools`` event, or None if it is filtered out.

        "tool-started" events record their tool name under ``tool_call_id``.
        "tool-finished" events get that name attached and their ToolMessage
        output decoded. Finished events of FILTERED_TOOL_NAMES are dropped.
        """
        tool_call_id = data.get("tool_call_id")
        event_kind = data.get("event")
        if event_kind == "tool-started":
            tool_names[tool_call_id] = data.get("tool_name")
        elif event_kind == "tool-finished":
            tool_name = tool_names.get(tool_call_id, "unknown")
            if tool_name in FILTERED_TOOL_NAMES:
                return None
            data["tool_name"] = tool_name
            output = data.get("output")
            if isinstance(output, ToolMessage):
                data["output"] = _decode_tool_output(output)
        return _sse({"event_type": "tool", "data": data})

    @staticmethod
    def _message_frame(event_data: Any) -> str | None:
        """Build the SSE frame for a ``messages`` event, or None if not forwarded.

        ``event_data`` is indexed as ``[payload, metadata?]``. Messages from
        _HIDDEN_MESSAGE_NODES are dropped. Only text/reasoning deltas and the
        message/content-block lifecycle events are forwarded.
        """
        if not event_data:
            return None
        data = event_data[0]
        metadata = event_data[1] if len(event_data) > 1 else {}
        if not isinstance(data, dict):
            return None
        if isinstance(metadata, dict) and metadata.get("langgraph_node") in _HIDDEN_MESSAGE_NODES:
            return None
        ev = data.get("event")
        if ev == "content-block-delta":
            block = data.get("delta") or {}
            if block.get("type") in _FORWARDED_DELTA_TYPES:
                return _sse({"event_type": "messages", "data": {"event": ev} | block})
            return None
        if ev in _FORWARDED_MESSAGE_EVENTS:
            return _sse({"event_type": "messages", "data": {"event": ev} | data})
        return None


if __name__ == "__main__":
    import asyncio

    async def test_stream():
        """Manual smoke test: stream one request and print the raw SSE frames."""
        with open("app/service/fashion_agent/graph_node/design_graph/design_request_data.json", encoding="utf-8") as f:
            # Only used by the commented-out design_request_data option below.
            request_data = json.load(f)
        service = FashionAgentService()
        print("=" * 50)
        print("测试流式输出")
        print("=" * 50)
        request = FashionAgentRequest(
            thread_id="zhh",
            message="落日",
            # call_print=True,
            # input_images=["test/53d38bd5-f77b-4034-ada2-45f1e2ebe00c.png"],
            # print_need_prompt_generation=False,
            # call_sketch=True,
            # sketch_need_prompt_generation=False,
            # call_logo=True,
            call_explore=True,
            # call_design=True,
            # design_request_data=request_data,
        )
        async for chunk in service.run_stream(request):
            print(chunk, end="")

    # Run the smoke test.
    asyncio.run(test_stream())