aida agent (基础版)搭建完成
This commit is contained in:
132
app/service/fashion_agent/service.py
Normal file
132
app/service/fashion_agent/service.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from langgraph.stream import ProtocolEvent, StreamChannel, StreamTransformer
|
||||
from app.service.fashion_agent.main_agent import build_main_graph
|
||||
from langgraph.prebuilt import ToolCallTransformer
|
||||
from typing import AsyncGenerator, TypedDict
|
||||
from langchain_core.messages import HumanMessage
|
||||
from app.schemas.fashion_agent import FashionAgentRequest
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class CustomTransformer(StreamTransformer):
|
||||
required_stream_modes = ("custom",)
|
||||
|
||||
def __init__(self, scope: tuple[str, ...] = ()) -> None:
|
||||
super().__init__(scope)
|
||||
self.log = StreamChannel()
|
||||
|
||||
def init(self) -> dict:
|
||||
return {"custom": self.log}
|
||||
|
||||
def process(self, event: ProtocolEvent) -> bool:
|
||||
if event["method"] == "custom":
|
||||
self.log.push(event["params"]["data"])
|
||||
return True
|
||||
|
||||
|
||||
class FashionAgentService:
|
||||
|
||||
async def run_stream(self, request: FashionAgentRequest) -> AsyncGenerator[str, None]:
|
||||
"""流式运行 agent - 使用 v3 projections"""
|
||||
|
||||
config = {"configurable": {"user_id": request.user_id}}
|
||||
|
||||
agent = build_main_graph(enable_thinking=request.enable_thinking)
|
||||
state = {
|
||||
"messages": [HumanMessage(content=request.message)],
|
||||
"call_print": request.call_print,
|
||||
"call_logo": request.call_logo,
|
||||
"call_sketch": request.call_sketch,
|
||||
"call_design": request.call_design,
|
||||
"call_trending": request.call_trending,
|
||||
"call_explor": request.call_explor,
|
||||
"print_need_prompt_generation": request.print_need_prompt_generation,
|
||||
"sketch_need_prompt_generation": request.sketch_need_prompt_generation,
|
||||
"design_request_data": request.design_request_data,
|
||||
}
|
||||
|
||||
stream = await agent.astream_events(state, config=config, version="v3", transformers=[ToolCallTransformer, CustomTransformer])
|
||||
|
||||
tool_names = {}
|
||||
filter_tool_name = ["design_tool"]
|
||||
async for event in stream:
|
||||
if event["method"] == "tools":
|
||||
data = event["params"]["data"]
|
||||
tool_call_id = data.get("tool_call_id")
|
||||
|
||||
# 记录 tool_name
|
||||
if data.get("event") == "tool-started":
|
||||
tool_names[tool_call_id] = data.get("tool_name")
|
||||
|
||||
# 通过 ID 查找 tool_name
|
||||
elif data.get("event") == "tool-finished":
|
||||
tool_name = tool_names.get(tool_call_id, "unknown")
|
||||
|
||||
if tool_name in filter_tool_name:
|
||||
continue
|
||||
|
||||
data["tool_name"] = tool_name
|
||||
|
||||
response_event = {"event_type": "tool", "data": data}
|
||||
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif event["method"] == "custom":
|
||||
data = event["params"]["data"]
|
||||
response_event = {"event_type": "tool", "data": data}
|
||||
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif event["method"] == "messages":
|
||||
data = event["params"]["data"][0]
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
if data.get("event") != "content-block-delta":
|
||||
continue
|
||||
block = data.get("delta") or {}
|
||||
|
||||
if block.get("type") == "text-delta":
|
||||
response_event = {"event_type": "messages", "data": {"event": data["event"]} | block}
|
||||
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
||||
elif block.get("type") == "reasoning-delta":
|
||||
response_event = {"event_type": "messages", "data": {"event": data["event"]} | block}
|
||||
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
pass
|
||||
# print(f"----------------{event}")
|
||||
response_event = {"event_type": "done"}
|
||||
yield f"data: {response_event}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
async def test_stream():
|
||||
"""测试流式调用"""
|
||||
|
||||
with open("app/service/fashion_agent/graph_node/design_graph/design_request_data.json", encoding="utf-8") as f:
|
||||
request_data = json.load(f)
|
||||
|
||||
service = FashionAgentService()
|
||||
|
||||
print("=" * 50)
|
||||
print("测试流式输出")
|
||||
print("=" * 50)
|
||||
request = FashionAgentRequest(
|
||||
message="生成一张草莓图案",
|
||||
call_print=True,
|
||||
# print_need_prompt_generation=False,
|
||||
# call_sketch=True,
|
||||
# sketch_need_prompt_generation=False,
|
||||
# call_logo=True,
|
||||
# call_explor=True,
|
||||
# call_design=True,
|
||||
# design_request_data=request_data,
|
||||
)
|
||||
async for chunk in service.run_stream(request):
|
||||
print(chunk, end="")
|
||||
|
||||
# 运行测试
|
||||
asyncio.run(test_stream())
|
||||
Reference in New Issue
Block a user