- 更新 `generate_image.py`，支持传入输入图像以增强图像生成。 - 修改 `pexels_search.py` 和 `unsplash_search.py`，将日志记录和上传路径中的 “explorer” 改为 “explore”。 - 调整 `print_graph` 和 `sketch_graph`：提取最新一条用户输入，并结合输入图像生成印花图和草图。 - 重构 `generate_print_tool` 和 `generate_sketch_tool`，使其接受输入图像。 - 更新 `main_agent.py`，在状态中加入输入图像，并调整图的构建流程。 - 增强 `service.py`，支持管理输入图像，并改进流式输出期间的事件处理。 - 更新 `pyproject.toml` 和 `uv.lock` 中的依赖项，加入新的软件包及版本。
146 lines
5.9 KiB
Python
146 lines
5.9 KiB
Python
import json
import logging
import os
from typing import AsyncGenerator

from langchain_core.messages import HumanMessage, ToolMessage
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.prebuilt import ToolCallTransformer
from langgraph.stream import ProtocolEvent, StreamChannel, StreamTransformer

from app.schemas.fashion_agent import FashionAgentRequest
from app.service.fashion_agent.main_agent import build_main_graph
|
|
|
|
# Module-level logger; named after this module rather than the root logger so
# its level/handlers can be configured independently.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CustomTransformer(StreamTransformer):
    """Stream transformer that collects ``custom`` stream-mode payloads on a channel.

    For every protocol event whose ``method`` is ``"custom"``, the event's
    ``params["data"]`` is pushed onto :attr:`log`. :meth:`init` exposes that
    channel under the ``"custom"`` key.
    """

    # Stream modes this transformer needs the graph to emit.
    required_stream_modes = ("custom",)

    def __init__(self, scope: tuple[str, ...] = ()) -> None:
        super().__init__(scope)
        # Channel receiving the ``data`` payload of each ``custom`` event.
        self.log = StreamChannel()

    def init(self) -> dict:
        """Return the mapping of projection name to channel contributed by this transformer."""
        return {"custom": self.log}

    def process(self, event: ProtocolEvent) -> bool:
        """Push the payload of a ``custom`` event onto :attr:`log`.

        Always returns ``True``. Presumably this tells the stream machinery to
        keep the event; TODO confirm against the StreamTransformer contract.
        """
        is_custom = event["method"] == "custom"
        if is_custom:
            payload = event["params"]["data"]
            self.log.push(payload)
        return True
|
|
|
|
|
|
class FashionAgentService:
    """Run the fashion main agent graph and stream its events as Server-Sent Events.

    Every chunk yielded by :meth:`run_stream` is an SSE frame: ``data: <json>``
    followed by a blank line. The JSON object carries an ``event_type`` of
    ``"tool"`` or ``"messages"``; the final frame has ``event_type`` ``"done"``.
    """

    # Environment variable read for the Postgres checkpointer connection string
    # when no ``db_url`` is passed to the constructor.
    DB_URL_ENV_VAR = "FASHION_AGENT_DB_URL"

    # Tools whose ``tool-finished`` events are not forwarded to the client.
    FILTERED_TOOL_NAMES = frozenset({"design_tool"})

    # Graph nodes whose ``messages`` events are not forwarded to the client.
    HIDDEN_MESSAGE_NODES = frozenset({"gen_prompt", "generate_query"})

    # ``event`` values of ``messages`` events that are forwarded as-is.
    _PASSTHROUGH_MESSAGE_EVENTS = ("message-start", "content-block-start", "content-block-finish", "message-finish")

    # Delta types inside ``content-block-delta`` events that are forwarded.
    _FORWARDED_DELTA_TYPES = ("text-delta", "reasoning-delta")

    def __init__(self, db_url: "str | None" = None) -> None:
        """Create the service.

        Args:
            db_url: Postgres connection string for the LangGraph checkpointer.
                When omitted, it is read from the ``FASHION_AGENT_DB_URL``
                environment variable each time a stream starts. Credentials must
                not be hard-coded in source.
        """
        self._db_url = db_url

    def _resolve_db_url(self) -> str:
        """Return the checkpointer connection string.

        Raises:
            RuntimeError: if neither ``db_url`` nor the environment variable is set.
        """
        db_url = self._db_url or os.environ.get(self.DB_URL_ENV_VAR)
        if not db_url:
            raise RuntimeError(
                "Postgres connection string is not configured: pass db_url to "
                f"FashionAgentService or set the {self.DB_URL_ENV_VAR} environment variable."
            )
        return db_url

    @staticmethod
    def _sse(payload: dict) -> str:
        """Serialize ``payload`` as one SSE ``data`` frame (JSON, non-ASCII kept readable)."""
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    @staticmethod
    def _decode_tool_content(content):
        """Decode a tool result's JSON text, or return ``content`` unchanged.

        Non-string content (e.g. a list of content blocks) and text that is not
        valid JSON (e.g. a tool error message) are returned as-is. Previously,
        either case raised and aborted the whole stream.
        """
        if isinstance(content, (str, bytes, bytearray)):
            try:
                return json.loads(content)
            except (json.JSONDecodeError, UnicodeDecodeError):
                return content
        return content

    @staticmethod
    def _build_state(request: "FashionAgentRequest") -> dict:
        """Translate the API request into the initial state passed to the main graph."""
        return {
            "messages": [HumanMessage(content=request.message)],
            "input_images": request.input_images,
            "call_print": request.call_print,
            "call_logo": request.call_logo,
            "call_sketch": request.call_sketch,
            "call_design": request.call_design,
            "call_trending": request.call_trending,
            "call_explore": request.call_explore,
            "print_need_prompt_generation": request.print_need_prompt_generation,
            "sketch_need_prompt_generation": request.sketch_need_prompt_generation,
            "design_request_data": request.design_request_data,
        }

    def _tool_event(self, data: dict, tool_names: dict) -> "dict | None":
        """Build the client payload for a ``tools`` event, or return ``None`` to drop it.

        Args:
            data: The event's ``params["data"]`` dict; mutated in place for
                ``tool-finished`` events (``tool_name`` added, ``output`` decoded).
            tool_names: Per-stream map of ``tool_call_id`` to tool name, filled
                from ``tool-started`` events.
        """
        tool_call_id = data.get("tool_call_id")
        kind = data.get("event")

        if kind == "tool-started":
            # The finished branch looks the name up by tool_call_id.
            tool_names[tool_call_id] = data.get("tool_name")
        elif kind == "tool-finished":
            # The call is complete, so its bookkeeping entry is no longer needed.
            tool_name = tool_names.pop(tool_call_id, "unknown")
            if tool_name in self.FILTERED_TOOL_NAMES:
                return None
            data["tool_name"] = tool_name
            output = data.get("output")
            if isinstance(output, ToolMessage):
                data["output"] = self._decode_tool_content(output.content)

        # NOTE(review): every other tools event (e.g. tool-started) is forwarded for
        # all tools, including filtered ones; only finished payloads are suppressed.
        return {"event_type": "tool", "data": data}

    def _message_event(self, event_data) -> "dict | None":
        """Build the client payload for a ``messages`` event, or return ``None`` to drop it.

        ``event_data`` is indexed as ``(data, metadata)``: both elements are
        optional and default to ``{}``.
        """
        data = event_data[0] if len(event_data) > 0 else {}
        metadata = event_data[1] if len(event_data) > 1 else {}

        if not isinstance(data, dict):
            return None
        if metadata.get("langgraph_node") in self.HIDDEN_MESSAGE_NODES:
            return None

        ev = data.get("event")
        if ev == "content-block-delta":
            block = data.get("delta") or {}
            if block.get("type") in self._FORWARDED_DELTA_TYPES:
                return {"event_type": "messages", "data": {"event": ev} | block}
            return None
        if ev in self._PASSTHROUGH_MESSAGE_EVENTS:
            return {"event_type": "messages", "data": {"event": ev} | data}
        return None

    async def run_stream(self, request: "FashionAgentRequest") -> AsyncGenerator[str, None]:
        """Run the agent for ``request`` and yield SSE frames (v3 streaming projections).

        Raises:
            RuntimeError: if no checkpointer connection string is configured.
        """
        db_url = self._resolve_db_url()
        config = {"configurable": {"thread_id": request.thread_id, "user_id": request.user_id}}

        async with AsyncPostgresSaver.from_conn_string(db_url) as checkpointer:
            # NOTE(review): setup() runs on every request; presumably idempotent
            # (creates checkpoint tables if missing). Consider moving it to app startup.
            await checkpointer.setup()
            agent = await build_main_graph(enable_thinking=request.enable_thinking, checkpointer=checkpointer)

            stream = await agent.astream_events(
                self._build_state(request),
                config=config,
                version="v3",
                transformers=[ToolCallTransformer, CustomTransformer],
            )

            # tool_call_id -> tool name, for the lifetime of this stream.
            tool_names: dict = {}
            async for event in stream:
                method = event["method"]
                if method == "tools":
                    payload = self._tool_event(event["params"]["data"], tool_names)
                elif method == "custom":
                    # Custom-channel payloads are surfaced to the client as tool events.
                    payload = {"event_type": "tool", "data": event["params"]["data"]}
                elif method == "messages":
                    payload = self._message_event(event["params"]["data"])
                else:
                    payload = None

                if payload is not None:
                    yield self._sse(payload)

        # Terminal frame: valid JSON with the SSE terminator (previously a Python
        # dict repr with no trailing blank line).
        yield self._sse({"event_type": "done"})
|
|
|
|
|
|
if __name__ == "__main__":
    import asyncio

    async def _demo_stream() -> None:
        """Manual smoke test: stream a single request and echo each SSE frame to stdout."""
        fixture_path = "app/service/fashion_agent/graph_node/design_graph/design_request_data.json"
        with open(fixture_path, encoding="utf-8") as fixture:
            # Only consumed by the (currently disabled) design_request_data option below.
            design_payload = json.load(fixture)

        svc = FashionAgentService()

        banner = "=" * 50
        print(banner)
        print("测试流式输出")
        print(banner)

        # Uncomment options below to enable other request flags.
        req = FashionAgentRequest(
            thread_id="zhh",
            message="落日",
            # call_print=True,
            # input_images=["test/53d38bd5-f77b-4034-ada2-45f1e2ebe00c.png"],
            # print_need_prompt_generation=False,
            # call_sketch=True,
            # sketch_need_prompt_generation=False,
            # call_logo=True,
            call_explore=True,
            # call_design=True,
            # design_request_data=design_payload,
        )

        async for frame in svc.run_stream(req):
            print(frame, end="")

    # Run the smoke test.
    asyncio.run(_demo_stream())
|