aida agent (基础版)搭建完成
This commit is contained in:
178
app/service/fashion_agent/graph_node/sketch_graph/graph.py
Normal file
178
app/service/fashion_agent/graph_node/sketch_graph/graph.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Annotated, Required, TypedDict
|
||||
from langchain_qwq import ChatQwen
|
||||
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, AIMessage
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from pydantic import BaseModel, Field
|
||||
from app.service.fashion_agent.init_llm import qwen_plus_llm
|
||||
|
||||
from app.service.fashion_agent.graph_node.sketch_graph.tools import generate_sketch_tool
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
"""定义状态"""
|
||||
|
||||
|
||||
class SketchState(TypedDict):
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
input_text: str
|
||||
role: str = ""
|
||||
gender: str = ""
|
||||
style: str = ""
|
||||
sketch_need_prompt_generation: bool = False # 是否需要使用 prompt 生成节点
|
||||
|
||||
sketch_num: int = 1
|
||||
|
||||
sketch_prompts: list[str] = []
|
||||
sketch_img_urls: list[str] = []
|
||||
|
||||
|
||||
"""生成服装草图的提示词节点"""
|
||||
|
||||
|
||||
# 定义输出结构
|
||||
class SketchPrompt(BaseModel):
|
||||
"""生成的印花图像提示词"""
|
||||
|
||||
prompts: list[str] = Field(description="用于生成服装草图的详细提示词")
|
||||
|
||||
|
||||
def extract_input_node(state: SketchState) -> dict:
|
||||
"""从 messages 中提取用户输入"""
|
||||
input_text = state["messages"][0].content if state.get("messages") else ""
|
||||
return {"input_text": input_text}
|
||||
|
||||
|
||||
def generate_sketch_prompt_node(state: SketchState) -> dict:
|
||||
"""根据用户输入生成服装草图的图像生成提示词"""
|
||||
structured_llm = qwen_plus_llm.with_structured_output(SketchPrompt)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=f"""你是一个专业的服装设计师。
|
||||
请根据用户输入,生成用于AI图像生成的服装草图提示词。
|
||||
|
||||
要求:
|
||||
1. 提示词必须包含:clean black and white line drawing only, pure white background, centered composition
|
||||
2. 提示词应该详细描述服装的廓形、结构、细节
|
||||
3. 提示词应该适合用于 Stable Diffusion 图像生成模型
|
||||
4. 提示词应该使用英文,因为图像生成模型对英文理解更好
|
||||
5. 草图风格必须是黑白线稿,不要添加颜色
|
||||
6. 提示词数量为 {state.get("sketch_num", 1)}
|
||||
"""),
|
||||
HumanMessage(content=state["input_text"]),
|
||||
]
|
||||
|
||||
result = structured_llm.invoke(messages)
|
||||
prompts = result.prompts
|
||||
|
||||
return {
|
||||
"sketch_prompts": prompts,
|
||||
}
|
||||
|
||||
|
||||
"""生成服装草图节点"""
|
||||
|
||||
|
||||
async def generate_sketch_img_node(state: SketchState) -> dict:
|
||||
"""根据生成的提示词,生成服装草图"""
|
||||
# 如果 sketch_need_prompt_generation=False 且 sketch_prompts 为空,使用模板生成 prompt
|
||||
# if not state.get("sketch_need_prompt_generation", False) and not state.get("sketch_prompts"):
|
||||
|
||||
# input_text = state.get("input_text", "")
|
||||
# prompts = [build_sketch_template_prompt(input_text)]
|
||||
# else:
|
||||
# prompts = state["sketch_prompts"] if state["sketch_prompts"] else [state["input_text"]]
|
||||
|
||||
# sketch_img_urls = []
|
||||
# for prompt in prompts:
|
||||
# image_url = await generate_sketch_tool.ainvoke({"prompt": prompt})
|
||||
# sketch_img_urls.append(image_url)
|
||||
|
||||
# result_text = f"服装草图生成完成,共生成 {len(sketch_img_urls)} 张图片:\n" + "\n".join(sketch_img_urls)
|
||||
# return {"sketch_img_urls": sketch_img_urls, "messages": [AIMessage(content=result_text)]}
|
||||
return {"messages": [AIMessage(content="hello")]}
|
||||
|
||||
|
||||
"""条件分支 判断是否需要生成 prompt"""
|
||||
|
||||
|
||||
def should_generate_prompt(state: SketchState) -> str:
|
||||
"""条件分支:判断是否需要生成 prompt"""
|
||||
if state.get("sketch_need_prompt_generation", False):
|
||||
return "gen_prompt"
|
||||
else:
|
||||
return "gen_sketch"
|
||||
|
||||
|
||||
def build_sketch_graph():
|
||||
workflow = StateGraph(SketchState)
|
||||
workflow.add_node("gen_sketch", generate_sketch_img_node)
|
||||
workflow.add_edge(START, "gen_sketch")
|
||||
workflow.add_edge("gen_sketch", END)
|
||||
graph = workflow.compile()
|
||||
return graph
|
||||
|
||||
# workflow = StateGraph(SketchState)
|
||||
# workflow.add_node("extract_input", extract_input_node)
|
||||
# workflow.add_node("gen_prompt", generate_sketch_prompt_node)
|
||||
# workflow.add_node("gen_sketch", generate_sketch_img_node)
|
||||
|
||||
# # 添加边
|
||||
# workflow.add_edge(START, "extract_input")
|
||||
# workflow.add_conditional_edges(
|
||||
# "extract_input",
|
||||
# should_generate_prompt,
|
||||
# {
|
||||
# "gen_prompt": "gen_prompt",
|
||||
# "gen_sketch": "gen_sketch",
|
||||
# },
|
||||
# )
|
||||
# workflow.add_edge("gen_prompt", "gen_sketch")
|
||||
# workflow.add_edge("gen_sketch", END)
|
||||
|
||||
# graph = workflow.compile()
|
||||
# return graph
|
||||
|
||||
|
||||
def build_sketch_template_prompt(input_text: str) -> str:
|
||||
"""构建 sketch prompt 模板"""
|
||||
return f"{input_text}, clean black and white line drawing only, pure white background, centered composition, fashion sketch style"
|
||||
|
||||
|
||||
async def main(test_input, sketch_need_prompt_generation=False):
|
||||
graph = build_sketch_graph()
|
||||
|
||||
# 如果不需要 LLM 生成 prompt,使用模板
|
||||
if not sketch_need_prompt_generation:
|
||||
sketch_prompts = [build_sketch_template_prompt(test_input)]
|
||||
else:
|
||||
sketch_prompts = []
|
||||
|
||||
result = await graph.ainvoke(
|
||||
{
|
||||
"input_text": test_input,
|
||||
"sketch_prompts": sketch_prompts,
|
||||
"sketch_need_prompt_generation": sketch_need_prompt_generation,
|
||||
"role": "",
|
||||
"gender": "",
|
||||
"style": "",
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试示例 1: 直接使用模板 prompt(默认)
|
||||
test_input = "dress"
|
||||
result = asyncio.run(main(test_input, sketch_need_prompt_generation=False))
|
||||
print("=== 使用模板 prompt ===")
|
||||
print(f"Result: {result}")
|
||||
|
||||
# # 测试示例 2: 使用 LLM 生成 prompt
|
||||
# test_input = "设计一条优雅的A字廓形连衣裙,V领设计,收腰,裙摆到膝盖,适合日常穿着"
|
||||
# result = asyncio.run(main(test_input, sketch_need_prompt_generation=True))
|
||||
# print("\n=== 使用 LLM 生成 prompt ===")
|
||||
# print(f"Result: {result}")
|
||||
33
app/service/fashion_agent/graph_node/sketch_graph/tools.py
Normal file
33
app/service/fashion_agent/graph_node/sketch_graph/tools.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import asyncio
|
||||
|
||||
from langchain.tools import tool
|
||||
from langsmith import uuid7
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.service.fashion_agent.graph_node.node_tools.generate_image import generate_image
|
||||
|
||||
|
||||
class GenerateImageToolInput(BaseModel):
|
||||
"""Input schema for the Generate Image Tool."""
|
||||
|
||||
prompt: str = Field(description="Description of the desired image, e.g., 'A cozy living room with warm lighting and natural textures.'")
|
||||
|
||||
|
||||
@tool(args_schema=GenerateImageToolInput)
|
||||
async def generate_sketch_tool(prompt: str) -> str:
|
||||
"""Generate an image based on the provided prompt."""
|
||||
|
||||
bucket_name = "fida-public-bucket"
|
||||
object_name = f"test/{uuid7()}.png"
|
||||
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
|
||||
return image_url
|
||||
|
||||
|
||||
async def run_test():
|
||||
result = await generate_sketch_tool.ainvoke({"prompt": "A cozy living room with warm lighting and natural textures."})
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = asyncio.run(run_test())
|
||||
print(result)
|
||||
Reference in New Issue
Block a user