aida agent (基础版)搭建完成

This commit is contained in:
zcr
2026-06-15 14:48:17 +08:00
parent b602c47fc9
commit dbbaa7503c
25 changed files with 1953 additions and 717 deletions

View File

@@ -0,0 +1,158 @@
import asyncio
import logging
from typing import Annotated, Required, TypedDict
from langchain_qwq import ChatQwen
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from app.service.fashion_agent.init_llm import qwen_plus_llm
from app.service.fashion_agent.graph_node.print_graph.tools import generate_print_tool, test
logger = logging.getLogger()
"""定义状态"""
class PrintState(TypedDict):
messages: Required[Annotated[list[AnyMessage], add_messages]]
input_text: str
role: str = ""
gender: str = ""
style: str = ""
print_need_prompt_generation: bool = False # 是否需要使用 prompt 生成节点
print_num: int = 1
print_prompts: list[str] = []
print_img_urls: list[str] = []
"""生成印花图案的提示词节点"""
# 定义输出结构
class PrintPrompt(BaseModel):
"""生成的印花图像提示词"""
prompts: list[str] = Field(description="用于生成印花图案的详细提示词")
def extract_input_node(state: PrintState) -> dict:
"""从 messages 中提取用户输入"""
input_text = state["messages"][0].content if state.get("messages") else ""
return {"input_text": input_text}
def generate_print_prompt_node(state: PrintState) -> dict:
"""根据用户输入生成印花图案的图像生成提示词"""
structured_llm = qwen_plus_llm.with_structured_output(PrintPrompt)
messages = [
SystemMessage(content=f"""你是一个专业的印花图案设计师。
请根据用户输入生成用于AI图像生成的印花图案提示词。
要求:
1. 提示词应该详细描述印花图案的样式、元素、颜色、布局
2. 提示词应该适合用于 Stable Diffusion 图像生成模型
3. 提示词应该使用英文,因为图像生成模型对英文理解更好
4. 提示词数量为 {state.get("print_num", 1)}
"""),
HumanMessage(content=state["input_text"]),
]
result = structured_llm.invoke(messages)
prompts = result.prompts
logger.info(f"[Print Graph] Generated print prompts: {prompts}")
return {
"print_prompts": prompts,
}
"""生成印花图案节点"""
async def generate_print_img_node(state: PrintState) -> dict:
"""根据生成的提示词,生成印花图案"""
# 如果 print_prompts 为空,使用 input_text 作为 prompt
if state.get("print_need_prompt_generation", False):
prompts = state["print_prompts"] if state["print_prompts"] else [state["input_text"]]
else:
input_text = state.get("input_text", "")
prompts = [input_text]
print_img_urls = []
for prompt in prompts:
image_url = await generate_print_tool.ainvoke({"prompt": prompt})
print_img_urls.append(image_url)
logger.info(f"[Print Graph] Generated print image URL: {image_url}")
return {"print_img_urls": print_img_urls}
"""条件分支 判断是否需要生成 prompt"""
def should_generate_prompt(state: PrintState) -> str:
"""条件分支:判断是否需要生成 prompt"""
logger.info(
f"[Print Graph] should_generate_prompt: print_need_prompt_generation={state.get('print_need_prompt_generation')}, print_prompts={state.get('print_prompts')}"
)
if state.get("print_need_prompt_generation", True):
return "gen_prompt"
else:
return "gen_print"
def build_print_graph():
workflow = StateGraph(PrintState)
workflow.add_node("extract_input", extract_input_node)
workflow.add_node("gen_prompt", generate_print_prompt_node)
workflow.add_node("gen_print", generate_print_img_node)
# 添加边
workflow.add_edge(START, "extract_input")
workflow.add_conditional_edges(
"extract_input",
should_generate_prompt,
{
"gen_prompt": "gen_prompt",
"gen_print": "gen_print",
},
)
workflow.add_edge("gen_prompt", "gen_print")
workflow.add_edge("gen_print", END)
graph = workflow.compile()
return graph
async def main(test_input, print_need_prompt_generation=True):
graph = build_print_graph()
result = await graph.ainvoke(
{
"input_text": test_input,
"print_prompts": [] if print_need_prompt_generation else [test_input],
"print_need_prompt_generation": print_need_prompt_generation,
"role": "",
"gender": "",
"style": "",
}
)
return result
if __name__ == "__main__":
# 测试示例 1: 需要 prompt 生成(默认)
test_input = "我想要一个优雅的花卉印花,适合用于连衣裙,颜色以粉色和白色为主"
result = asyncio.run(main(test_input, print_need_prompt_generation=True))
print("=== 需要 prompt 生成 ===")
print(f"Result: {result}")
# 测试示例 2: 直接使用用户提供的 prompt
user_prompt = "Elegant floral print pattern, pink and white colors, suitable for dress fabric, seamless tileable design"
result = asyncio.run(main(user_prompt, print_need_prompt_generation=False))
print("\n=== 直接使用 prompt ===")
print(f"Result: {result}")

View File

@@ -0,0 +1,39 @@
import asyncio
from langchain.tools import tool
from langsmith import uuid7
from pydantic import BaseModel, Field
from app.service.fashion_agent.graph_node.node_tools.generate_image import generate_image
class GenerateImageToolInput(BaseModel):
"""Input schema for the Generate Image Tool."""
prompt: str = Field(description="Description of the desired image, e.g., 'A cozy living room with warm lighting and natural textures.'")
@tool(args_schema=GenerateImageToolInput)
async def generate_print_tool(prompt: str) -> str:
"""Generate an image based on the provided prompt."""
bucket_name = "aida-users"
object_name = f"agent_generate_print/{uuid7()}.png"
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
return image_url
@tool
async def test(text: str):
"""测试工具函数,返回固定字符串"""
return text
async def run_test():
result = await generate_print_tool.ainvoke({"prompt": "A cozy living room with warm lighting and natural textures."})
return result
if __name__ == "__main__":
result = asyncio.run(run_test())
print(result)