重构图像生成和搜索工具;更新主代理来处理输入图像
- 更新了“generate_image.py”以接受输入图像以增强图像生成。 - 修改了`pexels_search.py`和`unsplash_search.py`以将日志记录和上传路径从“explorer”更改为“explore”。 - 调整了“print_graph”和“sketch_graph”以提取最新的用户输入并处理输入图像以生成打印和草图图像。 - 重构“generate_print_tool”和“generate_sketch_tool”以接受输入图像。 - 更新了“main_agent.py”以包含状态中的输入图像并调整了图形构建过程。 - 增强了“service.py”来管理输入图像并改进了流媒体期间的事件处理。 - 更新了新软件包和版本的“pyproject.toml”和“uv.lock”中的依赖项。
This commit is contained in:
@@ -9,8 +9,8 @@ from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from app.service.fashion_agent.graph_node.explorer_graph.tools import explor_tool
|
||||
from app.service.fashion_agent.init_llm import build_llm
|
||||
from app.service.fashion_agent.graph_node.explore_graph.tools import explore_tool
|
||||
from app.service.fashion_agent.init_llm import qwen_plus_llm
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
@@ -18,10 +18,10 @@ logger = logging.getLogger()
|
||||
"""定义状态"""
|
||||
|
||||
|
||||
class ExplorerState(TypedDict):
|
||||
class exploreState(TypedDict):
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
input_text: str
|
||||
search_query: str
|
||||
input_text: str = ""
|
||||
search_query: str = ""
|
||||
image_results: list[dict] # 每项包含 image_url 和 minio_path
|
||||
provider: str = "unsplash" # 图片源: "pexels" 或 "unsplash"
|
||||
|
||||
@@ -29,71 +29,70 @@ class ExplorerState(TypedDict):
|
||||
"""节点"""
|
||||
|
||||
|
||||
def extract_input_node(state: ExplorerState) -> dict:
|
||||
def extract_input_node(state: exploreState) -> dict:
|
||||
"""从 messages 中提取用户输入"""
|
||||
input_text = state["messages"][0].content if state.get("messages") else ""
|
||||
input_text = state["messages"][-1].content if state.get("messages") else input_text
|
||||
return {"input_text": input_text}
|
||||
|
||||
|
||||
class SearchQuery(BaseModel):
|
||||
"""搜索关键词"""
|
||||
|
||||
query: str = Field(description="用于搜索灵感图片的英文关键词,简洁有力")
|
||||
query: str = Field(description="用于搜索灵感图片的英文关键词,简洁有力")
|
||||
|
||||
|
||||
# TODO 要考虑搜索图片失败或者图片不存在的情况, 搜索不到 需要调整搜索词或者拆分搜索词,最终失败的话调用mood board生成工具生成, 保证绝对有图片
|
||||
async def generate_query_node(state: ExplorerState) -> dict:
|
||||
"""使用 LLM 分析用户输入,生成搜索关键词"""
|
||||
# TODO 要考虑搜索图片失败或者图片不存在的情况, 搜索不到 需要调整搜索词或者拆分搜索词,最终失败的话调用mood board生成工具生成, 保证绝对有图片
|
||||
async def generate_query_node(state: exploreState) -> dict:
|
||||
"""使用 LLM 分析用户输入,生成搜索关键词"""
|
||||
input_text = state["input_text"]
|
||||
logger.info(f"[Explorer] 用户输入: {input_text}")
|
||||
llm = build_llm()
|
||||
logger.info(f"[explore] 用户输入: {input_text}")
|
||||
|
||||
structured_llm = llm.with_structured_output(SearchQuery)
|
||||
structured_llm = qwen_plus_llm.with_structured_output(SearchQuery)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content="""你是一个专业的服装设计师助手。
|
||||
根据用户输入,生成一个英文搜索关键词,用于在图片库中搜索服装设计灵感图片(moodboard)。
|
||||
SystemMessage(content="""你是专业服装设计师助手.
|
||||
根据用户中文需求,生成适合时尚灵感图(moodboard)图库搜索的英文关键词短句.
|
||||
|
||||
要求:
|
||||
1. 使用英文,简洁有力
|
||||
2. 适合搜索高质量的设计灵感图片
|
||||
严格输出规则:
|
||||
1. 必须返回标准JSON对象,**禁止输出任何额外文字,解释,思考,前言后语**;
|
||||
2. JSON 只包含一个字段 "query",值为简洁英文搜索词;
|
||||
3. 关键词简洁,适配高清时尚素材搜索;
|
||||
|
||||
例如:
|
||||
用户输入:"夏季连衣裙,清新风格"
|
||||
输出:summer dress fresh style"""),
|
||||
输出格式示例(仅允许输出如下JSON,不要加别的内容):
|
||||
{"query": "summer dress fresh style"}"""),
|
||||
HumanMessage(content=input_text),
|
||||
]
|
||||
|
||||
result = structured_llm.invoke(messages)
|
||||
logger.info(f"[Explorer] LLM 生成的搜索关键词: {result.query}")
|
||||
return {"search_query": result.query}
|
||||
search_query = result.query
|
||||
logger.info(f"[explore] LLM 生成的搜索关键词: {search_query}")
|
||||
return {"search_query": search_query}
|
||||
|
||||
|
||||
async def search_and_upload_node(state: ExplorerState, config: RunnableConfig) -> dict:
|
||||
async def search_and_upload_node(state: exploreState, config: RunnableConfig) -> dict:
|
||||
"""使用搜索关键词获取图片并上传到 minio"""
|
||||
query = state.get("search_query", "")
|
||||
user_id = state.get("user_id", "agent")
|
||||
provider = state.get("provider", "unsplash")
|
||||
|
||||
try:
|
||||
results = await explor_tool.ainvoke({"query": query, "per_page": 4, "user_id": user_id, "method": provider}, config=config)
|
||||
results = await explore_tool.ainvoke({"query": query, "per_page": 4, "user_id": user_id, "method": provider}, config=config)
|
||||
except Exception as e:
|
||||
logger.error(f"[Explorer] 搜索失败 '{query}': {e}")
|
||||
logger.error(f"[explore] 搜索失败 '{query}': {e}")
|
||||
results = []
|
||||
|
||||
return {"image_results": results}
|
||||
|
||||
|
||||
def summarize_node(state: ExplorerState) -> dict:
|
||||
def summarize_node(state: exploreState) -> dict:
|
||||
"""汇总结果"""
|
||||
input_text = state.get("input_text", "")
|
||||
query = state.get("search_query", "")
|
||||
results = state.get("image_results", [])
|
||||
|
||||
result_text = f"【灵感探索 Moodboard】\n\n"
|
||||
result_text += f"基于您的需求:「{input_text}」\n"
|
||||
result_text += f"搜索关键词:{query}\n\n"
|
||||
result_text += f"已为您找到 {len(results)} 张灵感图片:\n"
|
||||
result_text = f"[灵感探索 Moodboard]\n\n"
|
||||
result_text += f"基于您的需求: {input_text}\n\n"
|
||||
result_text += f"搜索关键词:{query}\n\n"
|
||||
result_text += f"已为您找到 {len(results)} 张灵感图片:\n"
|
||||
|
||||
for i, item in enumerate(results, 1):
|
||||
result_text += f" {i}. 原图: {item.get('image_url', '')}\n"
|
||||
@@ -105,9 +104,9 @@ def summarize_node(state: ExplorerState) -> dict:
|
||||
"""构建图"""
|
||||
|
||||
|
||||
def build_explorer_graph():
|
||||
def build_explore_graph():
|
||||
"""构建灵感探索图"""
|
||||
workflow = StateGraph(ExplorerState)
|
||||
workflow = StateGraph(exploreState)
|
||||
|
||||
workflow.add_node("extract_input", extract_input_node)
|
||||
workflow.add_node("generate_query", generate_query_node)
|
||||
@@ -126,10 +125,10 @@ def build_explorer_graph():
|
||||
if __name__ == "__main__":
|
||||
|
||||
async def test():
|
||||
graph = build_explorer_graph()
|
||||
graph = build_explore_graph()
|
||||
result = await graph.ainvoke(
|
||||
{
|
||||
"messages": [HumanMessage(content="夏季连衣裙,清新自然风格")],
|
||||
"messages": [HumanMessage(content="夏季连衣裙,清新自然风格")],
|
||||
"provider": "unsplash",
|
||||
}
|
||||
)
|
||||
@@ -15,7 +15,7 @@ class SearchInput(BaseModel):
|
||||
|
||||
|
||||
@tool(args_schema=SearchInput)
|
||||
async def explor_tool(
|
||||
async def explore_tool(
|
||||
query: str, per_page: int = 4, user_id: str = "agent", method: str = "unsplash", config: RunnableConfig = None
|
||||
) -> list[dict]:
|
||||
"""Search for fashion inspiration images on Unsplash and upload to minio. Returns a list of dicts with image_url and minio_path."""
|
||||
@@ -7,7 +7,7 @@ from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from pydantic import BaseModel, Field
|
||||
from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool
|
||||
from app.service.fashion_agent.init_llm import qwen_plus_llm
|
||||
from app.service.fashion_agent.init_llm import build_llm, qwen_plus_llm
|
||||
|
||||
"""初始化 LLM TODO 将 API Key 替换为环境变量或者配置文件中的值,避免在代码中硬编码敏感信息"""
|
||||
|
||||
@@ -17,7 +17,7 @@ from app.service.fashion_agent.init_llm import qwen_plus_llm
|
||||
|
||||
class LogoState(TypedDict):
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
input_text: str
|
||||
input_text: str = ""
|
||||
user_id: str = "agent"
|
||||
role: str = ""
|
||||
gender: str = ""
|
||||
@@ -35,14 +35,14 @@ class LogoState(TypedDict):
|
||||
|
||||
# 定义输出结构
|
||||
class LogoPrompt(BaseModel):
|
||||
"""生成的 Logo 图像提示词"""
|
||||
"""logo image generation diagram prompt words"""
|
||||
|
||||
prompts: list[str] = Field(description="用于生成 Logo 的详细提示词")
|
||||
prompts: list[str] = Field(description="Array of prompt words, simple English words")
|
||||
|
||||
|
||||
def extract_input_node(state: LogoState) -> dict:
|
||||
"""从 messages 中提取用户输入"""
|
||||
input_text = state["messages"][0].content if state.get("messages") else ""
|
||||
input_text = state["messages"][-1].content if state.get("messages") else state["input_text"]
|
||||
return {"input_text": input_text}
|
||||
|
||||
|
||||
@@ -51,18 +51,19 @@ def generate_logo_prompt_node(state: LogoState) -> dict:
|
||||
structured_llm = qwen_plus_llm.with_structured_output(LogoPrompt)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content="""从用户输入中提取核心主题词,只输出一个简单的英文单词。
|
||||
SystemMessage(content="""从用户输入中提取核心主题词,用一个简单的英文单词作为 prompts 字段的值。
|
||||
例如:
|
||||
- "我想要一个猫咪图案" -> "cat"
|
||||
- "设计一个花朵" -> "flower"
|
||||
- "可爱的狗" -> "dog"
|
||||
只输出单词,不要其他内容。"""),
|
||||
- "我想要一个猫咪图案" -> {"prompts": ["cat"]}
|
||||
- "设计一个花朵" -> {"prompts": ["flower"]}
|
||||
- "可爱的狗" -> {"prompts": ["dog"]}
|
||||
请严格按照 JSON 格式输出,包含 prompts 字段。
|
||||
"""),
|
||||
HumanMessage(content=state["input_text"]),
|
||||
]
|
||||
|
||||
result = structured_llm.invoke(messages)
|
||||
prompts = result.prompts
|
||||
|
||||
print(result)
|
||||
return {
|
||||
"logo_prompts": prompts,
|
||||
}
|
||||
@@ -139,14 +140,14 @@ async def main(test_input, user_id="agent", need_prompt_generation=True):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试示例 1: 需要 prompt 生成(默认)- 简单关键词输入
|
||||
test_input = "我想要一个金毛图案"
|
||||
result = asyncio.run(main(test_input, need_prompt_generation=True))
|
||||
print("=== 需要 prompt 生成 ===")
|
||||
print(f"Result: {result}")
|
||||
|
||||
# 测试示例 2: 直接使用用户提供的 prompt
|
||||
user_prompt = "golden retriever"
|
||||
result = asyncio.run(main(user_prompt, need_prompt_generation=False))
|
||||
print("\n=== 直接使用 prompt ===")
|
||||
print(f"Result: {result}")
|
||||
async def test():
|
||||
graph = build_logo_graph()
|
||||
result = await graph.ainvoke(
|
||||
{
|
||||
"messages": [HumanMessage(content="我想要一个金毛图案")],
|
||||
}
|
||||
)
|
||||
print(result["messages"][-1].content)
|
||||
|
||||
asyncio.run(test())
|
||||
|
||||
@@ -4,13 +4,14 @@ import httpx
|
||||
async def generate_image(
|
||||
bucket_name="fida-public-bucket",
|
||||
object_name=f"furniture/sketches/123456.png",
|
||||
input_images=[],
|
||||
prompt="Generate a modern minimalist dining chair made of light "
|
||||
"oak wood and white leather, with slim metal legs, photographed "
|
||||
"in a bright Scandinavian living room with natural sunlight, high detail, "
|
||||
"8k resolution.",
|
||||
):
|
||||
request_data = {
|
||||
"input_image_paths": [],
|
||||
"input_image_paths": input_images,
|
||||
"prompt": prompt,
|
||||
"bucket_name": bucket_name,
|
||||
"object_name": object_name,
|
||||
|
||||
@@ -63,10 +63,10 @@ async def search_photos(query: str, per_page: int = 4, user_id: str = "agent") -
|
||||
# 上传到 minio(使用线程池避免阻塞事件循环)
|
||||
file_name = f"{uuid7()}.jpg"
|
||||
loop = asyncio.get_event_loop()
|
||||
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explorer", file_name)
|
||||
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explore", file_name)
|
||||
results.append({"image_url": image_url, "minio_path": minio_url})
|
||||
logger.info(f"[Explorer] 上传成功: {minio_url}")
|
||||
logger.info(f"[explore] 上传成功: {minio_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Explorer] 上传失败: {e}")
|
||||
logger.error(f"[explore] 上传失败: {e}")
|
||||
|
||||
return results
|
||||
|
||||
@@ -63,11 +63,11 @@ async def get_random_photos(query: str, count: int = 4, user_id: str = "agent")
|
||||
# 上传到 minio(使用线程池避免阻塞事件循环)
|
||||
file_name = f"{uuid7()}.jpg"
|
||||
loop = asyncio.get_event_loop()
|
||||
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explorer", file_name)
|
||||
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explore", file_name)
|
||||
results.append({"image_url": image_url, "minio_path": minio_url})
|
||||
logger.info(f"[Explorer] 上传成功: {minio_url}")
|
||||
logger.info(f"[explore] 上传成功: {minio_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Explorer] 上传失败: {e}")
|
||||
logger.error(f"[explore] 上传失败: {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class PrintPrompt(BaseModel):
|
||||
|
||||
def extract_input_node(state: PrintState) -> dict:
|
||||
"""从 messages 中提取用户输入"""
|
||||
input_text = state["messages"][0].content if state.get("messages") else ""
|
||||
input_text = state["messages"][-1].content if state.get("messages") else ""
|
||||
return {"input_text": input_text}
|
||||
|
||||
|
||||
@@ -49,13 +49,13 @@ def generate_print_prompt_node(state: PrintState) -> dict:
|
||||
structured_llm = qwen_plus_llm.with_structured_output(PrintPrompt)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=f"""你是一个专业的印花图案设计师。
|
||||
请根据用户输入,生成用于AI图像生成的印花图案提示词。
|
||||
SystemMessage(content=f"""你是一个专业的印花图案设计师.
|
||||
请根据用户输入,生成用于AI图像生成的印花图案提示词.
|
||||
|
||||
要求:
|
||||
1. 提示词应该详细描述印花图案的样式、元素、颜色、布局
|
||||
要求:
|
||||
1. 提示词应该详细描述印花图案的样式,元素,颜色,布局
|
||||
2. 提示词应该适合用于 Stable Diffusion 图像生成模型
|
||||
3. 提示词应该使用英文,因为图像生成模型对英文理解更好
|
||||
3. 提示词应该使用英文,因为图像生成模型对英文理解更好
|
||||
4. 提示词数量为 {state.get("print_num", 1)}
|
||||
"""),
|
||||
HumanMessage(content=state["input_text"]),
|
||||
@@ -73,17 +73,19 @@ def generate_print_prompt_node(state: PrintState) -> dict:
|
||||
|
||||
|
||||
async def generate_print_img_node(state: PrintState) -> dict:
|
||||
"""根据生成的提示词,生成印花图案"""
|
||||
# 如果 print_prompts 为空,使用 input_text 作为 prompt
|
||||
"""根据生成的提示词,生成印花图案"""
|
||||
# 如果 print_prompts 为空,使用 input_text 作为 prompt
|
||||
if state.get("print_need_prompt_generation", False):
|
||||
prompts = state["print_prompts"] if state["print_prompts"] else [state["input_text"]]
|
||||
else:
|
||||
input_text = state.get("input_text", "")
|
||||
prompts = [input_text]
|
||||
|
||||
input_images = state.get("input_images", [])
|
||||
|
||||
print_img_urls = []
|
||||
for prompt in prompts:
|
||||
image_url = await generate_print_tool.ainvoke({"prompt": prompt})
|
||||
image_url = await generate_print_tool.ainvoke({"prompt": prompt, "input_images": input_images})
|
||||
print_img_urls.append(image_url)
|
||||
logger.info(f"[Print Graph] Generated print image URL: {image_url}")
|
||||
|
||||
@@ -94,7 +96,7 @@ async def generate_print_img_node(state: PrintState) -> dict:
|
||||
|
||||
|
||||
def should_generate_prompt(state: PrintState) -> str:
|
||||
"""条件分支:判断是否需要生成 prompt"""
|
||||
"""条件分支:判断是否需要生成 prompt"""
|
||||
|
||||
logger.info(
|
||||
f"[Print Graph] should_generate_prompt: print_need_prompt_generation={state.get('print_need_prompt_generation')}, print_prompts={state.get('print_prompts')}"
|
||||
@@ -106,7 +108,6 @@ def should_generate_prompt(state: PrintState) -> str:
|
||||
|
||||
|
||||
def build_print_graph():
|
||||
|
||||
workflow = StateGraph(PrintState)
|
||||
workflow.add_node("extract_input", extract_input_node)
|
||||
workflow.add_node("gen_prompt", generate_print_prompt_node)
|
||||
@@ -145,8 +146,8 @@ async def main(test_input, print_need_prompt_generation=True):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试示例 1: 需要 prompt 生成(默认)
|
||||
test_input = "我想要一个优雅的花卉印花,适合用于连衣裙,颜色以粉色和白色为主"
|
||||
# 测试示例 1: 需要 prompt 生成(默认)
|
||||
test_input = "我想要一个优雅的花卉印花,适合用于连衣裙,颜色以粉色和白色为主"
|
||||
result = asyncio.run(main(test_input, print_need_prompt_generation=True))
|
||||
print("=== 需要 prompt 生成 ===")
|
||||
print(f"Result: {result}")
|
||||
|
||||
@@ -11,15 +11,16 @@ class GenerateImageToolInput(BaseModel):
|
||||
"""Input schema for the Generate Image Tool."""
|
||||
|
||||
prompt: str = Field(description="Description of the desired image, e.g., 'A cozy living room with warm lighting and natural textures.'")
|
||||
input_images: list[str] = Field(default=[], description="Input images for the generation.")
|
||||
|
||||
|
||||
@tool(args_schema=GenerateImageToolInput)
|
||||
async def generate_print_tool(prompt: str) -> str:
|
||||
async def generate_print_tool(prompt: str, input_images: list[str]) -> str:
|
||||
"""Generate an image based on the provided prompt."""
|
||||
|
||||
bucket_name = "aida-users"
|
||||
object_name = f"agent_generate_print/{uuid7()}.png"
|
||||
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
|
||||
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name, input_images=input_images)
|
||||
return [image_url]
|
||||
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class SketchPrompt(BaseModel):
|
||||
|
||||
def extract_input_node(state: SketchState) -> dict:
|
||||
"""从 messages 中提取用户输入"""
|
||||
input_text = state["messages"][0].content if state.get("messages") else ""
|
||||
input_text = state["messages"][-1].content if state.get("messages") else ""
|
||||
return {"input_text": input_text}
|
||||
|
||||
|
||||
@@ -80,15 +80,16 @@ async def generate_sketch_img_node(state: SketchState) -> dict:
|
||||
"""根据生成的提示词,生成服装草图"""
|
||||
# 如果 sketch_need_prompt_generation=False 且 sketch_prompts 为空,使用模板生成 prompt
|
||||
if not state.get("sketch_need_prompt_generation", False) and not state.get("sketch_prompts"):
|
||||
|
||||
input_text = state.get("input_text", "")
|
||||
prompts = [build_sketch_template_prompt(input_text)]
|
||||
else:
|
||||
prompts = state["sketch_prompts"] if state["sketch_prompts"] else [state["input_text"]]
|
||||
|
||||
input_images = state.get("input_images", [])
|
||||
|
||||
sketch_img_urls = []
|
||||
for prompt in prompts:
|
||||
image_url = await generate_sketch_tool.ainvoke({"prompt": prompt})
|
||||
image_url = await generate_sketch_tool.ainvoke({"prompt": prompt, "input_images": input_images})
|
||||
sketch_img_urls.append(image_url)
|
||||
|
||||
return {"sketch_img_urls": sketch_img_urls}
|
||||
@@ -107,33 +108,26 @@ def should_generate_prompt(state: SketchState) -> str:
|
||||
|
||||
def build_sketch_graph():
|
||||
workflow = StateGraph(SketchState)
|
||||
workflow.add_node("extract_input", extract_input_node)
|
||||
workflow.add_node("gen_prompt", generate_sketch_prompt_node)
|
||||
workflow.add_node("gen_sketch", generate_sketch_img_node)
|
||||
workflow.add_edge(START, "gen_sketch")
|
||||
|
||||
# 添加边
|
||||
workflow.add_edge(START, "extract_input")
|
||||
workflow.add_conditional_edges(
|
||||
"extract_input",
|
||||
should_generate_prompt,
|
||||
{
|
||||
"gen_prompt": "gen_prompt",
|
||||
"gen_sketch": "gen_sketch",
|
||||
},
|
||||
)
|
||||
workflow.add_edge("gen_prompt", "gen_sketch")
|
||||
workflow.add_edge("gen_sketch", END)
|
||||
|
||||
graph = workflow.compile()
|
||||
return graph
|
||||
|
||||
# workflow = StateGraph(SketchState)
|
||||
# workflow.add_node("extract_input", extract_input_node)
|
||||
# workflow.add_node("gen_prompt", generate_sketch_prompt_node)
|
||||
# workflow.add_node("gen_sketch", generate_sketch_img_node)
|
||||
|
||||
# # 添加边
|
||||
# workflow.add_edge(START, "extract_input")
|
||||
# workflow.add_conditional_edges(
|
||||
# "extract_input",
|
||||
# should_generate_prompt,
|
||||
# {
|
||||
# "gen_prompt": "gen_prompt",
|
||||
# "gen_sketch": "gen_sketch",
|
||||
# },
|
||||
# )
|
||||
# workflow.add_edge("gen_prompt", "gen_sketch")
|
||||
# workflow.add_edge("gen_sketch", END)
|
||||
|
||||
# graph = workflow.compile()
|
||||
# return graph
|
||||
|
||||
|
||||
def build_sketch_template_prompt(input_text: str) -> str:
|
||||
"""构建 sketch prompt 模板"""
|
||||
|
||||
@@ -11,15 +11,16 @@ class GenerateImageToolInput(BaseModel):
|
||||
"""Input schema for the Generate Image Tool."""
|
||||
|
||||
prompt: str = Field(description="Description of the desired image, e.g., 'A cozy living room with warm lighting and natural textures.'")
|
||||
input_images: list[str] = Field(default=[], description="Input images for the generation.")
|
||||
|
||||
|
||||
@tool(args_schema=GenerateImageToolInput)
|
||||
async def generate_sketch_tool(prompt: str) -> str:
|
||||
async def generate_sketch_tool(prompt: str, input_images: list[str]) -> str:
|
||||
"""Generate an image based on the provided prompt."""
|
||||
|
||||
bucket_name = "fida-public-bucket"
|
||||
object_name = f"test/{uuid7()}.png"
|
||||
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
|
||||
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name, input_images=input_images)
|
||||
return [image_url]
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Required, TypedDict
|
||||
from deepagents import CompiledSubAgent, create_deep_agent
|
||||
from langchain.agents import create_agent
|
||||
from langchain.tools import tool
|
||||
from langchain_core.messages import AnyMessage, HumanMessage
|
||||
from langchain_qwq import ChatQwen
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
from app.service.fashion_agent.graph_node.design_graph.graph import build_design_graph
|
||||
from app.service.fashion_agent.graph_node.design_graph.tools import design_tool
|
||||
from app.service.fashion_agent.graph_node.explorer_graph.tools import explor_tool
|
||||
from app.service.fashion_agent.graph_node.explore_graph.tools import explore_tool
|
||||
from app.service.fashion_agent.graph_node.logo_graph.graph import build_logo_graph
|
||||
from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool
|
||||
from app.service.fashion_agent.graph_node.print_graph.graph import build_print_graph
|
||||
@@ -18,7 +13,7 @@ from app.service.fashion_agent.graph_node.print_graph.tools import generate_prin
|
||||
from app.service.fashion_agent.graph_node.sketch_graph.graph import build_sketch_graph
|
||||
from app.service.fashion_agent.graph_node.sketch_graph.tools import generate_sketch_tool
|
||||
from app.service.fashion_agent.graph_node.trending_graph.trending_graph import build_trending_graph
|
||||
from app.service.fashion_agent.graph_node.explorer_graph.graph import build_explorer_graph
|
||||
from app.service.fashion_agent.graph_node.explore_graph.graph import build_explore_graph
|
||||
from app.service.fashion_agent.init_llm import build_llm
|
||||
|
||||
print_graph = build_print_graph()
|
||||
@@ -26,13 +21,16 @@ logo_graph = build_logo_graph()
|
||||
sketch_graph = build_sketch_graph()
|
||||
design_graph = build_design_graph()
|
||||
trending_graph = build_trending_graph()
|
||||
explorer_graph = build_explorer_graph()
|
||||
explore_graph = build_explore_graph()
|
||||
|
||||
|
||||
class MainState(TypedDict):
|
||||
# 消息
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
|
||||
# 上传图片
|
||||
input_images: list[str] = []
|
||||
|
||||
# 模块控制
|
||||
call_design: bool = False
|
||||
call_print: bool = False
|
||||
@@ -40,7 +38,7 @@ class MainState(TypedDict):
|
||||
call_sketch: bool = False
|
||||
call_design: bool = False
|
||||
call_trending: bool = False
|
||||
call_explor: bool = False
|
||||
call_explore: bool = False
|
||||
|
||||
# design参数
|
||||
design_request_data: dict = {}
|
||||
@@ -58,7 +56,7 @@ class MainState(TypedDict):
|
||||
print_img_urls: list[str] = []
|
||||
|
||||
|
||||
tools = [explor_tool, generate_logo_tool, generate_print_tool, generate_sketch_tool]
|
||||
tools = [explore_tool, generate_logo_tool, generate_print_tool, generate_sketch_tool]
|
||||
|
||||
|
||||
def route_node(state: MainState) -> str:
|
||||
@@ -73,12 +71,12 @@ def route_node(state: MainState) -> str:
|
||||
return "direct_design"
|
||||
if state.get("call_trending"):
|
||||
return "direct_trending"
|
||||
if state.get("call_explor"):
|
||||
return "direct_explor"
|
||||
if state.get("call_explore"):
|
||||
return "direct_explore"
|
||||
return "llm_agent"
|
||||
|
||||
|
||||
def build_main_graph(enable_thinking: bool = False):
|
||||
async def build_main_graph(enable_thinking: bool = False, checkpointer=None):
|
||||
llm = build_llm(enable_thinking=enable_thinking)
|
||||
chat_agent = create_agent(
|
||||
model=llm, tools=tools, state_schema=MainState, system_prompt="你是一个专业的服装设计助手。根据用户需求,调用合适的工具完成任务."
|
||||
@@ -94,7 +92,7 @@ def build_main_graph(enable_thinking: bool = False):
|
||||
workflow.add_node("direct_sketch", sketch_graph)
|
||||
workflow.add_node("direct_design", design_graph)
|
||||
workflow.add_node("direct_trending", trending_graph)
|
||||
workflow.add_node("direct_explor", explorer_graph)
|
||||
workflow.add_node("direct_explore", explore_graph)
|
||||
|
||||
# 条件分支
|
||||
workflow.add_conditional_edges(
|
||||
@@ -107,7 +105,7 @@ def build_main_graph(enable_thinking: bool = False):
|
||||
"direct_sketch": "direct_sketch",
|
||||
"direct_design": "direct_design",
|
||||
"direct_trending": "direct_trending",
|
||||
"direct_explor": "direct_explor",
|
||||
"direct_explore": "direct_explore",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -118,27 +116,7 @@ def build_main_graph(enable_thinking: bool = False):
|
||||
workflow.add_edge("direct_sketch", END)
|
||||
workflow.add_edge("direct_design", END)
|
||||
workflow.add_edge("direct_trending", END)
|
||||
workflow.add_edge("direct_explor", END)
|
||||
workflow.add_edge("direct_explore", END)
|
||||
|
||||
return workflow.compile()
|
||||
|
||||
|
||||
agent = build_main_graph()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
async def test_direct():
|
||||
# 直接调用 sketch,跳过 LLM
|
||||
result = await agent.ainvoke(
|
||||
{
|
||||
"messages": [HumanMessage(content="我想设计衬衫,带有猫咪的图案")],
|
||||
"call_sketch": True,
|
||||
"sketch_need_prompt_generation": False,
|
||||
}
|
||||
)
|
||||
print("=== 直接调用 sketch ===")
|
||||
print(result["messages"][-1].content)
|
||||
|
||||
asyncio.run(test_direct())
|
||||
graph = workflow.compile(checkpointer=checkpointer)
|
||||
return graph
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from langgraph.stream import ProtocolEvent, StreamChannel, StreamTransformer
|
||||
from app.service.fashion_agent.main_agent import build_main_graph
|
||||
from langgraph.prebuilt import ToolCallTransformer
|
||||
from typing import AsyncGenerator, TypedDict
|
||||
from typing import AsyncGenerator
|
||||
from langchain_core.messages import HumanMessage, ToolMessage
|
||||
from app.schemas.fashion_agent import FashionAgentRequest
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
@@ -33,73 +32,83 @@ class FashionAgentService:
|
||||
async def run_stream(self, request: FashionAgentRequest) -> AsyncGenerator[str, None]:
|
||||
"""流式运行 agent - 使用 v3 projections"""
|
||||
|
||||
config = {"configurable": {"user_id": request.user_id}}
|
||||
config = {"configurable": {"thread_id": request.thread_id, "user_id": request.user_id}}
|
||||
|
||||
agent = build_main_graph(enable_thinking=request.enable_thinking)
|
||||
state = {
|
||||
"messages": [HumanMessage(content=request.message)],
|
||||
"call_print": request.call_print,
|
||||
"call_logo": request.call_logo,
|
||||
"call_sketch": request.call_sketch,
|
||||
"call_design": request.call_design,
|
||||
"call_trending": request.call_trending,
|
||||
"call_explor": request.call_explor,
|
||||
"print_need_prompt_generation": request.print_need_prompt_generation,
|
||||
"sketch_need_prompt_generation": request.sketch_need_prompt_generation,
|
||||
"design_request_data": request.design_request_data,
|
||||
}
|
||||
async with AsyncPostgresSaver.from_conn_string("postgresql://postgres:Aidlab123123!@20.1.1.43:15432/myapp_prod") as checkpointer:
|
||||
await checkpointer.setup()
|
||||
agent = await build_main_graph(enable_thinking=request.enable_thinking, checkpointer=checkpointer)
|
||||
|
||||
stream = await agent.astream_events(state, config=config, version="v3", transformers=[ToolCallTransformer, CustomTransformer])
|
||||
state = {
|
||||
"messages": [HumanMessage(content=request.message)],
|
||||
"input_images": request.input_images,
|
||||
"call_print": request.call_print,
|
||||
"call_logo": request.call_logo,
|
||||
"call_sketch": request.call_sketch,
|
||||
"call_design": request.call_design,
|
||||
"call_trending": request.call_trending,
|
||||
"call_explore": request.call_explore,
|
||||
"print_need_prompt_generation": request.print_need_prompt_generation,
|
||||
"sketch_need_prompt_generation": request.sketch_need_prompt_generation,
|
||||
"design_request_data": request.design_request_data,
|
||||
}
|
||||
|
||||
tool_names = {}
|
||||
filter_tool_name = ["design_tool"]
|
||||
async for event in stream:
|
||||
if event["method"] == "tools":
|
||||
data = event["params"]["data"]
|
||||
tool_call_id = data.get("tool_call_id")
|
||||
stream = await agent.astream_events(state, config=config, version="v3", transformers=[ToolCallTransformer, CustomTransformer])
|
||||
|
||||
# 记录 tool_name
|
||||
if data.get("event") == "tool-started":
|
||||
tool_names[tool_call_id] = data.get("tool_name")
|
||||
tool_names = {}
|
||||
filter_tool_name = ["design_tool"]
|
||||
async for event in stream:
|
||||
if event["method"] == "tools":
|
||||
data = event["params"]["data"]
|
||||
tool_call_id = data.get("tool_call_id")
|
||||
|
||||
# 通过 ID 查找 tool_name
|
||||
elif data.get("event") == "tool-finished":
|
||||
tool_name = tool_names.get(tool_call_id, "unknown")
|
||||
# 记录 tool_name
|
||||
if data.get("event") == "tool-started":
|
||||
tool_names[tool_call_id] = data.get("tool_name")
|
||||
|
||||
if tool_name in filter_tool_name:
|
||||
# 通过 ID 查找 tool_name
|
||||
elif data.get("event") == "tool-finished":
|
||||
tool_name = tool_names.get(tool_call_id, "unknown")
|
||||
|
||||
if tool_name in filter_tool_name:
|
||||
continue
|
||||
|
||||
data["tool_name"] = tool_name
|
||||
|
||||
if isinstance(data["output"], ToolMessage):
|
||||
data["output"] = json.loads(data["output"].content)
|
||||
|
||||
response_event = {"event_type": "tool", "data": data}
|
||||
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif event["method"] == "custom":
|
||||
data = event["params"]["data"]
|
||||
response_event = {"event_type": "tool", "data": data}
|
||||
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif event["method"] == "messages":
|
||||
event_data = event["params"]["data"]
|
||||
data = event_data[0] if len(event_data) > 0 else {}
|
||||
# 提取元数据 (如果有的话)
|
||||
metadata = event_data[1] if len(event_data) > 1 else {}
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
if metadata.get("langgraph_node") in {"gen_prompt", "generate_query"}:
|
||||
continue
|
||||
|
||||
data["tool_name"] = tool_name
|
||||
ev = data.get("event")
|
||||
|
||||
if isinstance(data["output"], ToolMessage):
|
||||
data["output"] = json.loads(data["output"].content)
|
||||
response_event = {"event_type": "tool", "data": data}
|
||||
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
||||
if ev == "content-block-delta":
|
||||
block = data.get("delta") or {}
|
||||
if block.get("type") in ("text-delta", "reasoning-delta"):
|
||||
response_event = {"event_type": "messages", "data": {"event": ev} | block}
|
||||
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif event["method"] == "custom":
|
||||
data = event["params"]["data"]
|
||||
response_event = {"event_type": "tool", "data": data}
|
||||
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
||||
elif ev in ("message-start", "content-block-start", "content-block-finish", "message-finish"):
|
||||
response_event = {"event_type": "messages", "data": {"event": ev} | data}
|
||||
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
||||
|
||||
elif event["method"] == "messages":
|
||||
data = event["params"]["data"][0]
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
if data.get("event") != "content-block-delta":
|
||||
continue
|
||||
block = data.get("delta") or {}
|
||||
|
||||
if block.get("type") == "text-delta":
|
||||
response_event = {"event_type": "messages", "data": {"event": data["event"]} | block}
|
||||
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
||||
elif block.get("type") == "reasoning-delta":
|
||||
response_event = {"event_type": "messages", "data": {"event": data["event"]} | block}
|
||||
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
pass
|
||||
# print(f"----------------{event}")
|
||||
response_event = {"event_type": "done"}
|
||||
yield f"data: {response_event}"
|
||||
response_event = {"event_type": "done"}
|
||||
yield f"data: {response_event}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -117,13 +126,15 @@ if __name__ == "__main__":
|
||||
print("测试流式输出")
|
||||
print("=" * 50)
|
||||
request = FashionAgentRequest(
|
||||
message="生成一张草莓图案",
|
||||
call_print=True,
|
||||
thread_id="zhh",
|
||||
message="落日",
|
||||
# call_print=True,
|
||||
# input_images=["test/53d38bd5-f77b-4034-ada2-45f1e2ebe00c.png"],
|
||||
# print_need_prompt_generation=False,
|
||||
# call_sketch=True,
|
||||
# sketch_need_prompt_generation=False,
|
||||
# call_logo=True,
|
||||
# call_explor=True,
|
||||
call_explore=True,
|
||||
# call_design=True,
|
||||
# design_request_data=request_data,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user